mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-06 05:00:23 +03:00
Merge branch 'master' into master
This commit is contained in:
commit
1c6f1ac148
24
run_tests.py
24
run_tests.py
|
@ -1,24 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
import unittest
|
||||
|
||||
if __name__ == '__main__':
|
||||
from telethon_tests import \
|
||||
CryptoTests, ParserTests, TLTests, UtilsTests, NetworkTests
|
||||
|
||||
test_classes = [CryptoTests, ParserTests, TLTests, UtilsTests]
|
||||
|
||||
network = input('Run network tests (y/n)?: ').lower() == 'y'
|
||||
if network:
|
||||
test_classes.append(NetworkTests)
|
||||
|
||||
loader = unittest.TestLoader()
|
||||
|
||||
suites_list = []
|
||||
for test_class in test_classes:
|
||||
suite = loader.loadTestsFromTestCase(test_class)
|
||||
suites_list.append(suite)
|
||||
|
||||
big_suite = unittest.TestSuite(suites_list)
|
||||
|
||||
runner = unittest.TextTestRunner()
|
||||
results = runner.run(big_suite)
|
|
@ -2,7 +2,7 @@ import logging
|
|||
from .client.telegramclient import TelegramClient
|
||||
from .network import connection
|
||||
from .tl import types, functions, custom
|
||||
from . import version, events, utils, errors
|
||||
from . import version, events, utils, errors, full_sync
|
||||
|
||||
|
||||
__version__ = version.__version__
|
||||
|
|
|
@ -2,8 +2,8 @@ import itertools
|
|||
import re
|
||||
|
||||
from .users import UserMethods
|
||||
from .. import utils
|
||||
from ..tl import types, custom
|
||||
from .. import default, utils
|
||||
from ..tl import types
|
||||
|
||||
|
||||
class MessageParseMethods(UserMethods):
|
||||
|
@ -62,7 +62,7 @@ class MessageParseMethods(UserMethods):
|
|||
"""
|
||||
Returns a (parsed message, entities) tuple depending on ``parse_mode``.
|
||||
"""
|
||||
if parse_mode == utils.Default:
|
||||
if parse_mode == default:
|
||||
parse_mode = self._parse_mode
|
||||
else:
|
||||
parse_mode = utils.sanitize_parse_mode(parse_mode)
|
||||
|
|
|
@ -8,8 +8,8 @@ from async_generator import async_generator, yield_
|
|||
from .messageparse import MessageParseMethods
|
||||
from .uploads import UploadMethods
|
||||
from .buttons import ButtonMethods
|
||||
from .. import utils, helpers
|
||||
from ..tl import types, functions, custom
|
||||
from .. import default, helpers, utils
|
||||
from ..tl import types, functions
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
@ -360,7 +360,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
|||
|
||||
async def send_message(
|
||||
self, entity, message='', *, reply_to=None,
|
||||
parse_mode=utils.Default, link_preview=True, file=None,
|
||||
parse_mode=default, link_preview=True, file=None,
|
||||
force_document=False, clear_draft=False, buttons=None,
|
||||
silent=None):
|
||||
"""
|
||||
|
@ -584,7 +584,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
|||
|
||||
async def edit_message(
|
||||
self, entity, message=None, text=None,
|
||||
*, parse_mode=utils.Default, link_preview=True, file=None,
|
||||
*, parse_mode=default, link_preview=True, file=None,
|
||||
buttons=None):
|
||||
"""
|
||||
Edits the given message ID (to change its contents or disable preview).
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
import platform
|
||||
|
@ -12,7 +11,6 @@ from .. import version
|
|||
from ..crypto import rsa
|
||||
from ..extensions import markdown
|
||||
from ..network import MTProtoSender, ConnectionTcpFull
|
||||
from ..network.mtprotostate import MTProtoState
|
||||
from ..sessions import Session, SQLiteSession, MemorySession
|
||||
from ..tl import TLObject, functions, types
|
||||
from ..tl.alltlobjects import LAYER
|
||||
|
@ -55,7 +53,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
|
||||
connection (`telethon.network.connection.common.Connection`, optional):
|
||||
The connection instance to be used when creating a new connection
|
||||
to the servers. If it's a type, the `proxy` argument will be used.
|
||||
to the servers. It **must** be a type.
|
||||
|
||||
Defaults to `telethon.network.connection.tcpfull.ConnectionTcpFull`.
|
||||
|
||||
|
@ -68,11 +66,11 @@ class TelegramBaseClient(abc.ABC):
|
|||
A tuple consisting of ``(socks.SOCKS5, 'host', port)``.
|
||||
See https://github.com/Anorov/PySocks#usage-1 for more.
|
||||
|
||||
timeout (`int` | `float` | `timedelta`, optional):
|
||||
The timeout to be used when connecting, sending and receiving
|
||||
responses from the network. This is **not** the timeout to
|
||||
be used when ``await``'ing for invoked requests, and you
|
||||
should use ``asyncio.wait`` or ``asyncio.wait_for`` for that.
|
||||
timeout (`int` | `float`, optional):
|
||||
The timeout in seconds to be used when connecting.
|
||||
This is **not** the timeout to be used when ``await``'ing for
|
||||
invoked requests, and you should use ``asyncio.wait`` or
|
||||
``asyncio.wait_for`` for that.
|
||||
|
||||
request_retries (`int`, optional):
|
||||
How many times a request should be retried. Request are retried
|
||||
|
@ -150,7 +148,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
connection=ConnectionTcpFull,
|
||||
use_ipv6=False,
|
||||
proxy=None,
|
||||
timeout=timedelta(seconds=10),
|
||||
timeout=10,
|
||||
request_retries=5,
|
||||
connection_retries=5,
|
||||
auto_reconnect=True,
|
||||
|
@ -205,11 +203,12 @@ class TelegramBaseClient(abc.ABC):
|
|||
|
||||
self._request_retries = request_retries or sys.maxsize
|
||||
self._connection_retries = connection_retries or sys.maxsize
|
||||
self._proxy = proxy
|
||||
self._timeout = timeout
|
||||
self._auto_reconnect = auto_reconnect
|
||||
|
||||
if isinstance(connection, type):
|
||||
connection = connection(
|
||||
proxy=proxy, timeout=timeout, loop=self._loop)
|
||||
assert isinstance(connection, type)
|
||||
self._connection = connection
|
||||
|
||||
# Used on connection. Capture the variables in a lambda since
|
||||
# exporting clients need to create this InvokeWithLayerRequest.
|
||||
|
@ -227,12 +226,12 @@ class TelegramBaseClient(abc.ABC):
|
|||
)
|
||||
)
|
||||
|
||||
state = MTProtoState(self.session.auth_key)
|
||||
self._connection = connection
|
||||
self._sender = MTProtoSender(
|
||||
state, connection, self._loop,
|
||||
self._loop,
|
||||
retries=self._connection_retries,
|
||||
auto_reconnect=self._auto_reconnect,
|
||||
connect_timeout=self._timeout,
|
||||
update_callback=self._handle_update,
|
||||
auth_key_callback=self._auth_key_callback,
|
||||
auto_reconnect_callback=self._handle_auto_reconnect
|
||||
|
@ -250,10 +249,6 @@ class TelegramBaseClient(abc.ABC):
|
|||
# Save whether the user is authorized here (a.k.a. logged in)
|
||||
self._authorized = None # None = We don't know yet
|
||||
|
||||
# Default PingRequest delay
|
||||
self._last_ping = datetime.now()
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
|
||||
self._updates_handle = None
|
||||
self._last_request = time.time()
|
||||
self._channel_pts = {}
|
||||
|
@ -309,8 +304,10 @@ class TelegramBaseClient(abc.ABC):
|
|||
"""
|
||||
Connects to Telegram.
|
||||
"""
|
||||
await self._sender.connect(
|
||||
self.session.server_address, self.session.port)
|
||||
await self._sender.connect(self.session.auth_key, self._connection(
|
||||
self.session.server_address, self.session.port,
|
||||
loop=self._loop, proxy=self._proxy
|
||||
))
|
||||
|
||||
await self._sender.send(self._init_with(
|
||||
functions.help.GetConfigRequest()))
|
||||
|
@ -373,7 +370,7 @@ class TelegramBaseClient(abc.ABC):
|
|||
await self.session.set_dc(dc.id, dc.ip_address, dc.port)
|
||||
# auth_key's are associated with a server, which has now changed
|
||||
# so it's not valid anymore. Set to None to force recreating it.
|
||||
self.session.auth_key = self._sender.state.auth_key = None
|
||||
self.session.auth_key = None
|
||||
await self.session.save()
|
||||
await self._disconnect()
|
||||
return await self.connect()
|
||||
|
@ -416,13 +413,13 @@ class TelegramBaseClient(abc.ABC):
|
|||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||
# for clearly showing how to export the authorization
|
||||
dc = await self._get_dc(dc_id)
|
||||
state = MTProtoState(None)
|
||||
# Can't reuse self._sender._connection as it has its own seqno.
|
||||
#
|
||||
# If one were to do that, Telegram would reset the connection
|
||||
# with no further clues.
|
||||
sender = MTProtoSender(state, self._connection.clone(), self._loop)
|
||||
await sender.connect(dc.ip_address, dc.port)
|
||||
sender = MTProtoSender(self._loop)
|
||||
await sender.connect(None, self._connection(
|
||||
dc.ip_address, dc.port, loop=self._loop, proxy=self._proxy))
|
||||
__log__.info('Exporting authorization for data center %s', dc)
|
||||
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
|
||||
req = self._init_with(functions.auth.ImportAuthorizationRequest(
|
||||
|
|
|
@ -5,11 +5,11 @@ import os
|
|||
import pathlib
|
||||
import re
|
||||
from io import BytesIO
|
||||
from mimetypes import guess_type
|
||||
|
||||
from .buttons import ButtonMethods
|
||||
from .messageparse import MessageParseMethods
|
||||
from .users import UserMethods
|
||||
from .buttons import ButtonMethods
|
||||
from .. import default
|
||||
from .. import utils, helpers
|
||||
from ..tl import types, functions, custom
|
||||
|
||||
|
@ -23,7 +23,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
|||
async def send_file(
|
||||
self, entity, file, *, caption='', force_document=False,
|
||||
progress_callback=None, reply_to=None, attributes=None,
|
||||
thumb=None, allow_cache=True, parse_mode=utils.Default,
|
||||
thumb=None, allow_cache=True, parse_mode=default,
|
||||
voice_note=False, video_note=False, buttons=None, silent=None,
|
||||
**kwargs):
|
||||
"""
|
||||
|
@ -180,7 +180,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
|
|||
|
||||
async def _send_album(self, entity, files, caption='',
|
||||
progress_callback=None, reply_to=None,
|
||||
parse_mode=utils.Default, silent=None):
|
||||
parse_mode=default, silent=None):
|
||||
"""Specialized version of .send_file for albums"""
|
||||
# We don't care if the user wants to avoid cache, we will use it
|
||||
# anyway. Why? The cached version will be exactly the same thing
|
||||
|
|
5
telethon/default.py
Normal file
5
telethon/default.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
"""
|
||||
Sentinel module to signify that a parameter should use its default value.
|
||||
|
||||
Useful when the default value or ``None`` are both valid options.
|
||||
"""
|
|
@ -124,7 +124,7 @@ class InlineQuery(EventBuilder):
|
|||
|
||||
async def answer(
|
||||
self, results=None, cache_time=0, *,
|
||||
gallery=False, private=False,
|
||||
gallery=False, next_offset=None, private=False,
|
||||
switch_pm=None, switch_pm_param=''):
|
||||
"""
|
||||
Answers the inline query with the given results.
|
||||
|
@ -147,6 +147,10 @@ class InlineQuery(EventBuilder):
|
|||
|
||||
gallery (`bool`, optional):
|
||||
Whether the results should show as a gallery (grid) or not.
|
||||
|
||||
next_offset (`str`, optional):
|
||||
The offset the client will send when the user scrolls the
|
||||
results and it repeats the request.
|
||||
|
||||
private (`bool`, optional):
|
||||
Whether the results should be cached by Telegram
|
||||
|
@ -163,11 +167,14 @@ class InlineQuery(EventBuilder):
|
|||
if self._answered:
|
||||
return
|
||||
|
||||
results = [self._as_awaitable(x, self._client.loop)
|
||||
for x in results]
|
||||
if results:
|
||||
results = [self._as_awaitable(x, self._client.loop)
|
||||
for x in results]
|
||||
|
||||
done, _ = await asyncio.wait(results, loop=self._client.loop)
|
||||
results = [x.result() for x in done]
|
||||
done, _ = await asyncio.wait(results, loop=self._client.loop)
|
||||
results = [x.result() for x in done]
|
||||
else:
|
||||
results = []
|
||||
|
||||
if switch_pm:
|
||||
switch_pm = types.InlineBotSwitchPM(switch_pm, switch_pm_param)
|
||||
|
@ -178,6 +185,7 @@ class InlineQuery(EventBuilder):
|
|||
results=results,
|
||||
cache_time=cache_time,
|
||||
gallery=gallery,
|
||||
next_offset=next_offset,
|
||||
private=private,
|
||||
switch_pm=switch_pm
|
||||
)
|
||||
|
|
|
@ -4,4 +4,3 @@ communication with support for cancelling the operation, and an utility class
|
|||
to read arbitrary binary data in a more comfortable way, with int/strings/etc.
|
||||
"""
|
||||
from .binaryreader import BinaryReader
|
||||
from .tcpclient import TcpClient
|
||||
|
|
|
@ -9,8 +9,8 @@ from html.parser import HTMLParser
|
|||
from ..tl.types import (
|
||||
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
|
||||
MessageEntityPre, MessageEntityEmail, MessageEntityUrl,
|
||||
MessageEntityTextUrl
|
||||
)
|
||||
MessageEntityTextUrl, MessageEntityMentionName
|
||||
)
|
||||
|
||||
|
||||
# Helpers from markdown.py
|
||||
|
@ -178,6 +178,9 @@ def unparse(text, entities):
|
|||
elif entity_type == MessageEntityTextUrl:
|
||||
html.append('<a href="{}">{}</a>'
|
||||
.format(escape(entity.url), entity_text))
|
||||
elif entity_type == MessageEntityMentionName:
|
||||
html.append('<a href="tg://user?id={}">{}</a>'
|
||||
.format(entity.user_id, entity_text))
|
||||
else:
|
||||
skip_entity = True
|
||||
last_offset = entity.offset + (0 if skip_entity else entity.length)
|
||||
|
|
|
@ -9,8 +9,8 @@ from ..helpers import add_surrogate, del_surrogate
|
|||
from ..tl import TLObject
|
||||
from ..tl.types import (
|
||||
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
|
||||
MessageEntityPre, MessageEntityTextUrl
|
||||
)
|
||||
MessageEntityPre, MessageEntityTextUrl, MessageEntityMentionName
|
||||
)
|
||||
|
||||
DEFAULT_DELIMITERS = {
|
||||
'**': MessageEntityBold,
|
||||
|
@ -161,11 +161,17 @@ def unparse(text, entities, delimiters=None, url_fmt=None):
|
|||
delimiter = delimiters.get(type(entity), None)
|
||||
if delimiter:
|
||||
text = text[:s] + delimiter + text[s:e] + delimiter + text[e:]
|
||||
elif isinstance(entity, MessageEntityTextUrl) and url_fmt:
|
||||
text = (
|
||||
text[:s] +
|
||||
add_surrogate(url_fmt.format(text[s:e], entity.url)) +
|
||||
text[e:]
|
||||
)
|
||||
elif url_fmt:
|
||||
url = None
|
||||
if isinstance(entity, MessageEntityTextUrl):
|
||||
url = entity.url
|
||||
elif isinstance(entity, MessageEntityMentionName):
|
||||
url = 'tg://user?id={}'.format(entity.user_id)
|
||||
if url:
|
||||
text = (
|
||||
text[:s] +
|
||||
add_surrogate(url_fmt.format(text[s:e], url)) +
|
||||
text[e:]
|
||||
)
|
||||
|
||||
return del_surrogate(text)
|
||||
|
|
|
@ -1,171 +0,0 @@
|
|||
"""
|
||||
This module holds a rough implementation of the C# TCP client.
|
||||
|
||||
This class is **not** safe across several tasks since partial reads
|
||||
may be ``await``'ed before being able to return the exact byte count.
|
||||
|
||||
This class is also not concerned about disconnections or retries of
|
||||
any sort, nor any other kind of errors such as connecting twice.
|
||||
"""
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
CONN_RESET_ERRNOS = {
|
||||
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
|
||||
errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
|
||||
errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
|
||||
errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
|
||||
}
|
||||
# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
|
||||
# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
|
||||
|
||||
try:
|
||||
import socks
|
||||
except ImportError:
|
||||
socks = None
|
||||
|
||||
SSL_PORT = 443
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TcpClient:
|
||||
"""A simple TCP client to ease the work with sockets and proxies."""
|
||||
|
||||
class SocketClosed(ConnectionError):
|
||||
pass
|
||||
|
||||
def __init__(self, *, loop, timeout, ssl=None, proxy=None):
|
||||
"""
|
||||
Initializes the TCP client.
|
||||
|
||||
:param proxy: the proxy to be used, if any.
|
||||
:param timeout: the timeout for connect, read and write operations.
|
||||
:param ssl: ssl.wrap_socket keyword arguments to use when connecting
|
||||
if port == SSL_PORT, or do nothing if not present.
|
||||
"""
|
||||
self._loop = loop
|
||||
self.proxy = proxy
|
||||
self.ssl = ssl
|
||||
self._socket = None
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._closed = asyncio.Event(loop=self._loop)
|
||||
self._closed.set()
|
||||
|
||||
if isinstance(timeout, (int, float)):
|
||||
self.timeout = float(timeout)
|
||||
elif hasattr(timeout, 'seconds'):
|
||||
self.timeout = float(timeout.seconds)
|
||||
else:
|
||||
raise TypeError('Invalid timeout type: {}'.format(type(timeout)))
|
||||
|
||||
@staticmethod
|
||||
def _create_socket(mode, proxy):
|
||||
if proxy is None:
|
||||
s = socket.socket(mode, socket.SOCK_STREAM)
|
||||
else:
|
||||
__log__.info('Connection will be made through proxy %s', proxy)
|
||||
import socks
|
||||
s = socks.socksocket(mode, socket.SOCK_STREAM)
|
||||
if isinstance(proxy, dict):
|
||||
s.set_proxy(**proxy)
|
||||
else: # tuple, list, etc.
|
||||
s.set_proxy(*proxy)
|
||||
s.setblocking(False)
|
||||
return s
|
||||
|
||||
async def connect(self, ip, port):
|
||||
"""
|
||||
Tries connecting to IP:port unless an OSError is raised.
|
||||
|
||||
:param ip: the IP to connect to.
|
||||
:param port: the port to connect to.
|
||||
"""
|
||||
if ':' in ip: # IPv6
|
||||
ip = ip.replace('[', '').replace(']', '')
|
||||
mode, address = socket.AF_INET6, (ip, port, 0, 0)
|
||||
else:
|
||||
mode, address = socket.AF_INET, (ip, port)
|
||||
|
||||
try:
|
||||
if self._socket is None:
|
||||
self._socket = self._create_socket(mode, self.proxy)
|
||||
wrap_ssl = self.ssl and port == SSL_PORT
|
||||
else:
|
||||
wrap_ssl = False
|
||||
|
||||
await asyncio.wait_for(
|
||||
self._loop.sock_connect(self._socket, address),
|
||||
timeout=self.timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
if wrap_ssl:
|
||||
# Temporarily set the socket to blocking
|
||||
# (timeout) until connection is established.
|
||||
self._socket.settimeout(self.timeout)
|
||||
self._socket = ssl.wrap_socket(
|
||||
self._socket, do_handshake_on_connect=True, **self.ssl)
|
||||
self._socket.setblocking(False)
|
||||
|
||||
self._closed.clear()
|
||||
self._reader, self._writer =\
|
||||
await asyncio.open_connection(sock=self._socket)
|
||||
except OSError as e:
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
raise ConnectionResetError() from e
|
||||
else:
|
||||
raise
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
"""Determines whether the client is connected or not."""
|
||||
return not self._closed.is_set()
|
||||
|
||||
def close(self):
|
||||
"""Closes the connection."""
|
||||
fd = None
|
||||
try:
|
||||
if self._writer is not None:
|
||||
self._writer.close()
|
||||
|
||||
if self._socket is not None:
|
||||
fd = self._socket.fileno()
|
||||
if self.is_connected:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
self._socket.close()
|
||||
except OSError:
|
||||
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
||||
finally:
|
||||
self._socket = None
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._closed.set()
|
||||
if fd and fd != -1:
|
||||
self._loop.remove_reader(fd)
|
||||
|
||||
async def write(self, data):
|
||||
"""
|
||||
Writes (sends) the specified bytes to the connected peer.
|
||||
:param data: the data to send.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise ConnectionResetError('Not connected')
|
||||
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
|
||||
async def read(self, size):
|
||||
"""
|
||||
Reads (receives) a whole block of size bytes from the connected peer.
|
||||
|
||||
:param size: the size of the block to be read.
|
||||
:return: the read data with len(data) == size.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise ConnectionResetError('Not connected')
|
||||
|
||||
return await self._reader.readexactly(size)
|
161
telethon/full_sync.py
Normal file
161
telethon/full_sync.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
"""
|
||||
This magical module will rewrite all public methods in the public interface of
|
||||
the library so they can delegate the call to an asyncio event loop in another
|
||||
thread and wait for the result. This rewrite may not be desirable if the end
|
||||
user always uses the methods they way they should be ran, but it's incredibly
|
||||
useful for quick scripts and legacy code.
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from async_generator import isasyncgenfunction
|
||||
|
||||
from .client.telegramclient import TelegramClient
|
||||
from .tl.custom import (
|
||||
Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
|
||||
)
|
||||
from .tl.custom.chatgetter import ChatGetter
|
||||
from .tl.custom.sendergetter import SenderGetter
|
||||
|
||||
|
||||
async def _proxy_future(af, cf):
|
||||
try:
|
||||
res = await af
|
||||
cf.set_result(res)
|
||||
except Exception as e:
|
||||
cf.set_exception(e)
|
||||
|
||||
|
||||
def _sync_result(loop, x):
|
||||
f = Future()
|
||||
loop.call_soon_threadsafe(asyncio.ensure_future, _proxy_future(x, f))
|
||||
return f.result()
|
||||
|
||||
|
||||
class _SyncGen:
|
||||
def __init__(self, loop, gen):
|
||||
self.loop = loop
|
||||
self.gen = gen
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return _sync_result(self.loop, self.gen.__anext__())
|
||||
except StopAsyncIteration:
|
||||
raise StopIteration from None
|
||||
|
||||
|
||||
def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result):
|
||||
method = getattr(t, method_name)
|
||||
|
||||
@functools.wraps(method)
|
||||
def syncified(*args, **kwargs):
|
||||
coro = method(*args, **kwargs)
|
||||
return (
|
||||
coro if threading.get_ident() == thread_ident
|
||||
else syncifier(loop, coro)
|
||||
)
|
||||
|
||||
setattr(t, method_name, syncified)
|
||||
|
||||
|
||||
def _syncify(*types, loop, thread_ident):
|
||||
for t in types:
|
||||
for method_name in dir(t):
|
||||
if not method_name.startswith('_') or method_name == '__call__':
|
||||
if inspect.iscoroutinefunction(getattr(t, method_name)):
|
||||
_syncify_wrap(t, method_name, loop, thread_ident, _sync_result)
|
||||
elif isasyncgenfunction(getattr(t, method_name)):
|
||||
_syncify_wrap(t, method_name, loop, thread_ident, _SyncGen)
|
||||
|
||||
|
||||
__asyncthread = None
|
||||
|
||||
|
||||
def enable(*, loop=None, executor=None, max_workers=1):
|
||||
"""
|
||||
Enables the fully synchronous mode. You should enable this at
|
||||
the beginning of your script, right after importing, only once.
|
||||
|
||||
**Please** make sure to call `stop` at the end of your script.
|
||||
|
||||
You can define the event loop to use and executor, otherwise
|
||||
the default loop and ``ThreadPoolExecutor`` will be used, in
|
||||
which case `max_workers` will be passed to it. If you pass a
|
||||
custom executor, `max_workers` will be ignored.
|
||||
"""
|
||||
global __asyncthread
|
||||
if __asyncthread is not None:
|
||||
raise RuntimeError("full_sync can only be enabled once")
|
||||
|
||||
if not loop:
|
||||
loop = asyncio.get_event_loop()
|
||||
if not executor:
|
||||
executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
def start():
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
__asyncthread = threading.Thread(
|
||||
target=start, name="__telethon_async_thread__", daemon=True
|
||||
)
|
||||
__asyncthread.start()
|
||||
__asyncthread.loop = loop
|
||||
__asyncthread.executor = executor
|
||||
|
||||
TelegramClient.__init__ = functools.partialmethod(
|
||||
TelegramClient.__init__, loop=loop
|
||||
)
|
||||
|
||||
_syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter,
|
||||
SenderGetter, Forward, Message, InlineResult, Conversation,
|
||||
loop=loop, thread_ident=__asyncthread.ident)
|
||||
_syncify_wrap(TelegramClient, "start", loop, __asyncthread.ident)
|
||||
|
||||
old_add_event_handler = TelegramClient.add_event_handler
|
||||
old_remove_event_handler = TelegramClient.remove_event_handler
|
||||
proxied_event_handlers = {}
|
||||
|
||||
@functools.wraps(old_add_event_handler)
|
||||
def add_proxied_event_handler(self, callback, *args, **kwargs):
|
||||
async def _proxy(*pargs, **pkwargs):
|
||||
await loop.run_in_executor(
|
||||
executor, functools.partial(callback, *pargs, **pkwargs))
|
||||
|
||||
proxied_event_handlers[callback] = _proxy
|
||||
|
||||
args = (self, _proxy, *args)
|
||||
return old_add_event_handler(*args, **kwargs)
|
||||
|
||||
@functools.wraps(old_remove_event_handler)
|
||||
def remove_proxied_event_handler(self, callback, *args, **kwargs):
|
||||
args = (self, proxied_event_handlers.get(callback, callback), *args)
|
||||
return old_remove_event_handler(*args, **kwargs)
|
||||
|
||||
TelegramClient.add_event_handler = add_proxied_event_handler
|
||||
TelegramClient.remove_event_handler = remove_proxied_event_handler
|
||||
|
||||
def run_until_disconnected(self):
|
||||
return _sync_result(loop, self._run_until_disconnected())
|
||||
|
||||
TelegramClient.run_until_disconnected = run_until_disconnected
|
||||
|
||||
return __asyncthread
|
||||
|
||||
|
||||
def stop():
|
||||
"""
|
||||
Stops the fully synchronous code. You
|
||||
should call this before your script exits.
|
||||
"""
|
||||
global __asyncthread
|
||||
if not __asyncthread:
|
||||
raise RuntimeError("Can't find asyncio thread")
|
||||
__asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop)
|
||||
__asyncthread.executor.shutdown()
|
|
@ -1,5 +1,5 @@
|
|||
"""Various helpers not related to the Telegram API itself"""
|
||||
import collections
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
from hashlib import sha1, sha256
|
||||
|
@ -87,4 +87,46 @@ class TotalList(list):
|
|||
return '[{}, total={}]'.format(
|
||||
', '.join(repr(x) for x in self), self.total)
|
||||
|
||||
|
||||
class _ReadyQueue:
|
||||
"""
|
||||
A queue list that supports an arbitrary cancellation token for `get`.
|
||||
"""
|
||||
def __init__(self, loop):
|
||||
self._list = []
|
||||
self._loop = loop
|
||||
self._ready = asyncio.Event(loop=loop)
|
||||
|
||||
def append(self, item):
|
||||
self._list.append(item)
|
||||
self._ready.set()
|
||||
|
||||
def extend(self, items):
|
||||
self._list.extend(items)
|
||||
self._ready.set()
|
||||
|
||||
async def get(self, cancellation):
|
||||
"""
|
||||
Returns a list of all the items added to the queue until now and
|
||||
clears the list from the queue itself. Returns ``None`` if cancelled.
|
||||
"""
|
||||
ready = self._loop.create_task(self._ready.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
[ready, cancellation],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
loop=self._loop
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
done = [cancellation]
|
||||
|
||||
if cancellation in done:
|
||||
ready.cancel()
|
||||
return None
|
||||
|
||||
result = self._list
|
||||
self._list = []
|
||||
self._ready.clear()
|
||||
return result
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,74 +0,0 @@
|
|||
"""
|
||||
This module holds the abstract `Connection` class.
|
||||
|
||||
The `Connection.send` and `Connection.recv` methods need **not** to be
|
||||
safe across several tasks and may use any amount of ``await`` keywords.
|
||||
|
||||
The code using these `Connection`'s should be responsible for using
|
||||
an ``async with asyncio.Lock:`` block when calling said methods.
|
||||
|
||||
Said subclasses need not to worry about reconnecting either, and
|
||||
should let the errors propagate instead.
|
||||
"""
|
||||
import abc
|
||||
|
||||
|
||||
class Connection(abc.ABC):
|
||||
"""
|
||||
Represents an abstract connection for Telegram.
|
||||
|
||||
Subclasses should implement the actual protocol
|
||||
being used when encoding/decoding messages.
|
||||
"""
|
||||
def __init__(self, *, loop, timeout, proxy=None):
|
||||
"""
|
||||
Initializes a new connection.
|
||||
|
||||
:param loop: the event loop to be used.
|
||||
:param timeout: timeout to be used for all operations.
|
||||
:param proxy: whether to use a proxy or not.
|
||||
"""
|
||||
self._loop = loop
|
||||
self._proxy = proxy
|
||||
self._timeout = timeout
|
||||
|
||||
@abc.abstractmethod
|
||||
async def connect(self, ip, port):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_timeout(self):
|
||||
"""Returns the timeout used by the connection."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_connected(self):
|
||||
"""
|
||||
Determines whether the connection is alive or not.
|
||||
|
||||
:return: true if it's connected.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self):
|
||||
"""Closes the connection."""
|
||||
raise NotImplementedError
|
||||
|
||||
def clone(self):
|
||||
"""Creates a copy of this Connection."""
|
||||
return self.__class__(
|
||||
loop=self._loop,
|
||||
proxy=self._proxy,
|
||||
timeout=self._timeout
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def recv(self):
|
||||
"""Receives and unpacks a message"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, message):
|
||||
"""Encapsulates and sends the given message"""
|
||||
raise NotImplementedError
|
180
telethon/network/connection/connection.py
Normal file
180
telethon/network/connection/connection.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import ssl as ssl_mod
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Connection(abc.ABC):
|
||||
"""
|
||||
The `Connection` class is a wrapper around ``asyncio.open_connection``.
|
||||
|
||||
Subclasses are meant to communicate with this class through a queue.
|
||||
|
||||
This class provides a reliable interface that will stay connected
|
||||
under any conditions for as long as the user doesn't disconnect or
|
||||
the input parameters to auto-reconnect dictate otherwise.
|
||||
"""
|
||||
def __init__(self, ip, port, *, loop, proxy=None):
|
||||
self._ip = ip
|
||||
self._port = port
|
||||
self._loop = loop
|
||||
self._proxy = proxy
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._disconnected = asyncio.Event(loop=loop)
|
||||
self._disconnected.set()
|
||||
self._disconnected_future = None
|
||||
self._send_task = None
|
||||
self._recv_task = None
|
||||
self._send_queue = asyncio.Queue(1)
|
||||
self._recv_queue = asyncio.Queue(1)
|
||||
|
||||
async def connect(self, timeout=None, ssl=None):
|
||||
"""
|
||||
Establishes a connection with the server.
|
||||
"""
|
||||
if not self._proxy:
|
||||
self._reader, self._writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(
|
||||
self._ip, self._port, loop=self._loop, ssl=ssl),
|
||||
loop=self._loop, timeout=timeout
|
||||
)
|
||||
else:
|
||||
import socks
|
||||
if ':' in self._ip:
|
||||
mode, address = socket.AF_INET6, (self._ip, self._port, 0, 0)
|
||||
else:
|
||||
mode, address = socket.AF_INET, (self._ip, self._port)
|
||||
|
||||
s = socks.socksocket(mode, socket.SOCK_STREAM)
|
||||
if isinstance(self._proxy, dict):
|
||||
s.set_proxy(**self._proxy)
|
||||
else:
|
||||
s.set_proxy(*self._proxy)
|
||||
|
||||
s.setblocking(False)
|
||||
await asyncio.wait_for(
|
||||
self._loop.sock_connect(s, address),
|
||||
timeout=timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
if ssl:
|
||||
self._socket.settimeout(timeout)
|
||||
self._socket = ssl_mod.wrap_socket(
|
||||
s,
|
||||
do_handshake_on_connect=True,
|
||||
ssl_version=ssl_mod.PROTOCOL_SSLv23,
|
||||
ciphers='ADH-AES256-SHA'
|
||||
)
|
||||
self._socket.setblocking(False)
|
||||
|
||||
self._reader, self._writer = await asyncio.open_connection(
|
||||
self._ip, self._port, loop=self._loop, sock=s
|
||||
)
|
||||
|
||||
self._disconnected.clear()
|
||||
self._disconnected_future = None
|
||||
self._send_task = self._loop.create_task(self._send_loop())
|
||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnects from the server.
|
||||
"""
|
||||
self._disconnected.set()
|
||||
if self._send_task:
|
||||
self._send_task.cancel()
|
||||
|
||||
if self._recv_task:
|
||||
self._recv_task.cancel()
|
||||
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
|
||||
@property
|
||||
def disconnected(self):
|
||||
if not self._disconnected_future:
|
||||
self._disconnected_future = \
|
||||
self._loop.create_task(self._disconnected.wait())
|
||||
return self._disconnected_future
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
Creates a clone of the connection.
|
||||
"""
|
||||
return self.__class__(self._ip, self._port, loop=self._loop)
|
||||
|
||||
def send(self, data):
|
||||
"""
|
||||
Sends a packet of data through this connection mode.
|
||||
|
||||
This method returns a coroutine.
|
||||
"""
|
||||
return self._send_queue.put(data)
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
Receives a packet of data through this connection mode.
|
||||
|
||||
This method returns a coroutine.
|
||||
"""
|
||||
ok, result = await self._recv_queue.get()
|
||||
if ok:
|
||||
return result
|
||||
else:
|
||||
raise result from None
|
||||
|
||||
async def _send_loop(self):
|
||||
"""
|
||||
This loop is constantly popping items off the queue to send them.
|
||||
"""
|
||||
try:
|
||||
while not self._disconnected.is_set():
|
||||
self._send(await self._send_queue.get())
|
||||
await self._writer.drain()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logging.exception('Unhandled exception in the sending loop')
|
||||
self.disconnect()
|
||||
|
||||
async def _recv_loop(self):
|
||||
"""
|
||||
This loop is constantly putting items on the queue as they're read.
|
||||
"""
|
||||
try:
|
||||
while not self._disconnected.is_set():
|
||||
data = await self._recv()
|
||||
await self._recv_queue.put((True, data))
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
await self._recv_queue.put((False, e))
|
||||
self.disconnect()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send(self, data):
|
||||
"""
|
||||
This method should be implemented differently under each
|
||||
connection mode and serialize the data into the packet
|
||||
the way it should be sent through `self._writer`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _recv(self):
|
||||
"""
|
||||
This method should be implemented differently under each
|
||||
connection mode and deserialize the data from the packet
|
||||
the way it should be read from `self._reader`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
return '{}:{}/{}'.format(
|
||||
self._ip, self._port,
|
||||
self.__class__.__name__.replace('Connection', '')
|
||||
)
|
|
@ -1,62 +1,34 @@
|
|||
import errno
|
||||
import ssl
|
||||
import asyncio
|
||||
|
||||
from .common import Connection
|
||||
from ...extensions import TcpClient
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
SSL_PORT = 443
|
||||
|
||||
|
||||
class ConnectionHttp(Connection):
|
||||
def __init__(self, *, loop, timeout, proxy=None):
|
||||
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
|
||||
self.conn = TcpClient(
|
||||
timeout=self._timeout, loop=self._loop, proxy=self._proxy,
|
||||
ssl=dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
|
||||
)
|
||||
self.read = self.conn.read
|
||||
self.write = self.conn.write
|
||||
self._host = None
|
||||
async def connect(self, timeout=None, ssl=None):
|
||||
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)
|
||||
|
||||
async def connect(self, ip, port):
|
||||
self._host = '{}:{}'.format(ip, port)
|
||||
try:
|
||||
await self.conn.connect(ip, port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EISCONN:
|
||||
return # Already connected, no need to re-set everything up
|
||||
else:
|
||||
raise
|
||||
|
||||
def get_timeout(self):
|
||||
return self.conn.timeout
|
||||
|
||||
def is_connected(self):
|
||||
return self.conn.is_connected
|
||||
|
||||
async def close(self):
|
||||
self.conn.close()
|
||||
|
||||
async def recv(self):
|
||||
while True:
|
||||
line = await self._read_line()
|
||||
if line.lower().startswith(b'content-length: '):
|
||||
await self.read(2)
|
||||
length = int(line[16:-2])
|
||||
return await self.read(length)
|
||||
|
||||
async def _read_line(self):
|
||||
newline = ord('\n')
|
||||
line = await self.read(1)
|
||||
while line[-1] != newline:
|
||||
line += await self.read(1)
|
||||
return line
|
||||
|
||||
async def send(self, message):
|
||||
await self.write(
|
||||
def _send(self, message):
|
||||
self._writer.write(
|
||||
'POST /api HTTP/1.1\r\n'
|
||||
'Host: {}\r\n'
|
||||
'Host: {}:{}\r\n'
|
||||
'Content-Type: application/x-www-form-urlencoded\r\n'
|
||||
'Connection: keep-alive\r\n'
|
||||
'Keep-Alive: timeout=100000, max=10000000\r\n'
|
||||
'Content-Length: {}\r\n\r\n'.format(self._host, len(message))
|
||||
'Content-Length: {}\r\n\r\n'
|
||||
.format(self._ip, self._port, len(message))
|
||||
.encode('ascii') + message
|
||||
)
|
||||
|
||||
async def _recv(self):
|
||||
while True:
|
||||
line = await self._reader.readline()
|
||||
if not line or line[-1] != b'\n':
|
||||
raise asyncio.IncompleteReadError(line, None)
|
||||
|
||||
if line.lower().startswith(b'content-length: '):
|
||||
await self._reader.readexactly(2)
|
||||
length = int(line[16:-2])
|
||||
return await self._reader.readexactly(length)
|
||||
|
|
|
@ -1,31 +1,44 @@
|
|||
import struct
|
||||
|
||||
from .tcpfull import ConnectionTcpFull
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
class ConnectionTcpAbridged(ConnectionTcpFull):
|
||||
class ConnectionTcpAbridged(Connection):
|
||||
"""
|
||||
This is the mode with the lowest overhead, as it will
|
||||
only require 1 byte if the packet length is less than
|
||||
508 bytes (127 << 2, which is very common).
|
||||
"""
|
||||
async def connect(self, ip, port):
|
||||
result = await super().connect(ip, port)
|
||||
await self.conn.write(b'\xef')
|
||||
return result
|
||||
async def connect(self, timeout=None, ssl=None):
|
||||
await super().connect(timeout=timeout, ssl=ssl)
|
||||
self._writer.write(b'\xef')
|
||||
await self._writer.drain()
|
||||
|
||||
async def recv(self):
|
||||
length = struct.unpack('<B', await self.read(1))[0]
|
||||
if length >= 127:
|
||||
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
|
||||
def _write(self, data):
|
||||
"""
|
||||
Define wrapper write methods for `TcpObfuscated` to override.
|
||||
"""
|
||||
self._writer.write(data)
|
||||
|
||||
return await self.read(length << 2)
|
||||
async def _read(self, n):
|
||||
"""
|
||||
Define wrapper read methods for `TcpObfuscated` to override.
|
||||
"""
|
||||
return await self._reader.readexactly(n)
|
||||
|
||||
async def send(self, message):
|
||||
length = len(message) >> 2
|
||||
def _send(self, data):
|
||||
length = len(data) >> 2
|
||||
if length < 127:
|
||||
length = struct.pack('B', length)
|
||||
else:
|
||||
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
||||
|
||||
await self.write(length + message)
|
||||
self._write(length + data)
|
||||
|
||||
async def _recv(self):
|
||||
length = struct.unpack('<B', await self._read(1))[0]
|
||||
if length >= 127:
|
||||
length = struct.unpack(
|
||||
'<i', await self._read(3) + b'\0')[0]
|
||||
|
||||
return await self._read(length << 2)
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
import errno
|
||||
import struct
|
||||
from zlib import crc32
|
||||
|
||||
from .common import Connection
|
||||
from .connection import Connection
|
||||
from ...errors import InvalidChecksumError
|
||||
from ...extensions import TcpClient
|
||||
|
||||
|
||||
class ConnectionTcpFull(Connection):
|
||||
|
@ -12,39 +10,27 @@ class ConnectionTcpFull(Connection):
|
|||
Default Telegram mode. Sends 12 additional bytes and
|
||||
needs to calculate the CRC value of the packet itself.
|
||||
"""
|
||||
def __init__(self, *, loop, timeout, proxy=None):
|
||||
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
|
||||
self._send_counter = 0
|
||||
self.conn = TcpClient(
|
||||
timeout=self._timeout, loop=self._loop, proxy=self._proxy
|
||||
)
|
||||
self.read = self.conn.read
|
||||
self.write = self.conn.write
|
||||
|
||||
async def connect(self, ip, port):
|
||||
try:
|
||||
await self.conn.connect(ip, port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EISCONN:
|
||||
return # Already connected, no need to re-set everything up
|
||||
else:
|
||||
raise
|
||||
|
||||
def __init__(self, ip, port, *, loop, proxy=None):
|
||||
super().__init__(ip, port, loop=loop, proxy=proxy)
|
||||
self._send_counter = 0
|
||||
|
||||
def get_timeout(self):
|
||||
return self.conn.timeout
|
||||
async def connect(self, timeout=None, ssl=None):
|
||||
await super().connect(timeout=timeout, ssl=ssl)
|
||||
self._send_counter = 0 # Important or Telegram won't reply
|
||||
|
||||
def is_connected(self):
|
||||
return self.conn.is_connected
|
||||
def _send(self, data):
|
||||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
length = len(data) + 12
|
||||
data = struct.pack('<ii', length, self._send_counter) + data
|
||||
crc = struct.pack('<I', crc32(data))
|
||||
self._send_counter += 1
|
||||
self._writer.write(data + crc)
|
||||
|
||||
async def close(self):
|
||||
self.conn.close()
|
||||
|
||||
async def recv(self):
|
||||
packet_len_seq = await self.read(8) # 4 and 4
|
||||
async def _recv(self):
|
||||
packet_len_seq = await self._reader.readexactly(8) # 4 and 4
|
||||
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
||||
body = await self.read(packet_len - 8)
|
||||
body = await self._reader.readexactly(packet_len - 8)
|
||||
checksum = struct.unpack('<I', body[-4:])[0]
|
||||
body = body[:-4]
|
||||
|
||||
|
@ -53,12 +39,3 @@ class ConnectionTcpFull(Connection):
|
|||
raise InvalidChecksumError(checksum, valid_checksum)
|
||||
|
||||
return body
|
||||
|
||||
async def send(self, message):
|
||||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
length = len(message) + 12
|
||||
data = struct.pack('<ii', length, self._send_counter) + message
|
||||
crc = struct.pack('<I', crc32(data))
|
||||
self._send_counter += 1
|
||||
await self.write(data + crc)
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
import struct
|
||||
|
||||
from .tcpfull import ConnectionTcpFull
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
class ConnectionTcpIntermediate(ConnectionTcpFull):
|
||||
class ConnectionTcpIntermediate(Connection):
|
||||
"""
|
||||
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
|
||||
Always sends 4 extra bytes for the packet length.
|
||||
"""
|
||||
async def connect(self, ip, port):
|
||||
result = await super().connect(ip, port)
|
||||
await self.conn.write(b'\xee\xee\xee\xee')
|
||||
return result
|
||||
async def connect(self, timeout=None, ssl=None):
|
||||
await super().connect(timeout=timeout, ssl=ssl)
|
||||
self._writer.write(b'\xee\xee\xee\xee')
|
||||
await self._writer.drain()
|
||||
|
||||
async def recv(self):
|
||||
return await self.read(struct.unpack('<i', await self.read(4))[0])
|
||||
def _send(self, data):
|
||||
self._writer.write(struct.pack('<i', len(data)) + data)
|
||||
|
||||
async def send(self, message):
|
||||
await self.write(struct.pack('<i', len(message)) + message)
|
||||
async def _recv(self):
|
||||
return await self._reader.readexactly(
|
||||
struct.unpack('<i', await self._reader.readexactly(4))[0])
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
|
||||
from .tcpabridged import ConnectionTcpAbridged
|
||||
from .tcpfull import ConnectionTcpFull
|
||||
from ...crypto import AESModeCTR
|
||||
|
||||
|
||||
|
@ -11,16 +10,22 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
|||
every message with a randomly generated key using the
|
||||
AES-CTR mode so the packets are harder to discern.
|
||||
"""
|
||||
def __init__(self, *, loop, timeout, proxy=None):
|
||||
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
|
||||
self._aes_encrypt, self._aes_decrypt = None, None
|
||||
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
|
||||
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
|
||||
def __init__(self, ip, port, *, loop, proxy=None):
|
||||
super().__init__(ip, port, loop=loop, proxy=proxy)
|
||||
self._aes_encrypt = None
|
||||
self._aes_decrypt = None
|
||||
|
||||
def _write(self, data):
|
||||
self._writer.write(self._aes_encrypt.encrypt(data))
|
||||
|
||||
async def _read(self, n):
|
||||
return self._aes_decrypt.encrypt(await self._reader.readexactly(n))
|
||||
|
||||
async def connect(self, timeout=None, ssl=None):
|
||||
await super().connect(timeout=timeout, ssl=ssl)
|
||||
|
||||
async def connect(self, ip, port):
|
||||
result = await ConnectionTcpFull.connect(self, ip, port)
|
||||
# Obfuscated messages secrets cannot start with any of these
|
||||
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
|
||||
keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee')
|
||||
while True:
|
||||
random = os.urandom(64)
|
||||
if (random[0] != b'\xef' and
|
||||
|
@ -28,11 +33,11 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
|||
random[4:4] != b'\0\0\0\0'):
|
||||
break
|
||||
|
||||
random = list(random)
|
||||
random = bytearray(random)
|
||||
random[56] = random[57] = random[58] = random[59] = 0xef
|
||||
random_reversed = random[55:7:-1] # Reversed (8, len=48)
|
||||
|
||||
# encryption has "continuous buffer" enabled
|
||||
# Encryption has "continuous buffer" enabled
|
||||
encrypt_key = bytes(random[8:40])
|
||||
encrypt_iv = bytes(random[40:56])
|
||||
decrypt_key = bytes(random_reversed[:32])
|
||||
|
@ -42,5 +47,5 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
|
|||
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
|
||||
|
||||
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
|
||||
await self.conn.write(bytes(random))
|
||||
return result
|
||||
self._writer.write(random)
|
||||
await self._writer.drain()
|
||||
|
|
158
telethon/network/mtprotolayer.py
Normal file
158
telethon/network/mtprotolayer.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
import io
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from .mtprotostate import MTProtoState
|
||||
from ..tl import TLRequest
|
||||
from ..tl.core.tlmessage import TLMessage
|
||||
from ..tl.core.messagecontainer import MessageContainer
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MTProtoLayer:
|
||||
"""
|
||||
This class is the message encryption layer between the methods defined
|
||||
in the schema and the response objects. It also holds the necessary state
|
||||
necessary for this encryption to happen.
|
||||
|
||||
The `connection` parameter is through which these messages will be sent
|
||||
and received.
|
||||
|
||||
The `auth_key` must be a valid authorization key which will be used to
|
||||
encrypt these messages. This class is not responsible for generating them.
|
||||
"""
|
||||
def __init__(self, connection, auth_key):
|
||||
self._connection = connection
|
||||
self._state = MTProtoState(auth_key)
|
||||
|
||||
def connect(self, timeout=None):
|
||||
"""
|
||||
Wrapper for ``self._connection.connect()``.
|
||||
"""
|
||||
return self._connection.connect(timeout=timeout)
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Wrapper for ``self._connection.disconnect()``.
|
||||
"""
|
||||
self._connection.disconnect()
|
||||
|
||||
def reset_state(self):
|
||||
self._state = MTProtoState(self._state.auth_key)
|
||||
|
||||
async def send(self, state_list):
|
||||
"""
|
||||
The list of `RequestState` that will be sent. They will
|
||||
be updated with their new message and container IDs.
|
||||
|
||||
Nested lists imply an order is required for the messages in them.
|
||||
Message containers will be used if there is more than one item.
|
||||
"""
|
||||
for data in filter(None, self._pack_state_list(state_list)):
|
||||
await self._connection.send(self._state.encrypt_message_data(data))
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
Reads a single message from the network, decrypts it and returns it.
|
||||
"""
|
||||
body = await self._connection.recv()
|
||||
return self._state.decrypt_message_data(body)
|
||||
|
||||
def _pack_state_list(self, state_list):
|
||||
"""
|
||||
The list of `RequestState` that will be sent. They will
|
||||
be updated with their new message and container IDs.
|
||||
|
||||
Packs all their serialized data into a message (possibly
|
||||
nested inside another message and message container) and
|
||||
returns the serialized message data.
|
||||
"""
|
||||
# Note that the simplest case is writing a single query data into
|
||||
# a message, and returning the message data and ID. For efficiency
|
||||
# purposes this method supports more than one message and automatically
|
||||
# uses containers if deemed necessary.
|
||||
#
|
||||
# Technically the message and message container classes could be used
|
||||
# to store and serialize the data. However, to keep the context local
|
||||
# and relevant to the only place where such feature is actually used,
|
||||
# this is not done.
|
||||
#
|
||||
# When iterating over the state_list there are two branches, one
|
||||
# being just a state and the other being a list so the inner states
|
||||
# depend on each other. In either case, if the packed size exceeds
|
||||
# the maximum container size, it must be sent. This code is non-
|
||||
# trivial so it has been factored into an inner function.
|
||||
#
|
||||
# A new buffer instance will be used once the size should be "flushed"
|
||||
buffer = io.BytesIO()
|
||||
# The batch of requests sent in a single buffer-flush. We need to
|
||||
# remember which states were written to set their container ID.
|
||||
batch = []
|
||||
# The currently written size. Reset when it exceeds the maximum.
|
||||
size = 0
|
||||
|
||||
def write_state(state, after_id=None):
|
||||
nonlocal buffer, batch, size
|
||||
if state:
|
||||
batch.append(state)
|
||||
size += len(state.data) + TLMessage.SIZE_OVERHEAD
|
||||
|
||||
# Flush whenever the current size exceeds the maximum,
|
||||
# or if there's no state, which indicates force flush.
|
||||
if not state or size > MessageContainer.MAXIMUM_SIZE:
|
||||
size -= MessageContainer.MAXIMUM_SIZE
|
||||
if len(batch) > 1:
|
||||
# Inlined code to pack several messages into a container
|
||||
data = struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
|
||||
) + buffer.getvalue()
|
||||
buffer = io.BytesIO()
|
||||
container_id = self._state.write_data_as_message(
|
||||
buffer, data, content_related=False
|
||||
)
|
||||
for s in batch:
|
||||
s.container_id = container_id
|
||||
|
||||
# At this point it's either a single msg or a msg + container
|
||||
data = buffer.getvalue()
|
||||
__log__.debug('Packed %d message(s) in %d bytes for sending',
|
||||
len(batch), len(data))
|
||||
batch.clear()
|
||||
buffer = io.BytesIO()
|
||||
return data
|
||||
|
||||
if not state:
|
||||
return # Just forcibly flushing
|
||||
|
||||
# If even after flushing it still exceeds the maximum size,
|
||||
# this message payload cannot be sent. Telegram would forcibly
|
||||
# close the connection, and the message would never be confirmed.
|
||||
if size > MessageContainer.MAXIMUM_SIZE:
|
||||
state.future.set_exception(
|
||||
ValueError('Request payload is too big'))
|
||||
return
|
||||
|
||||
# This is the only requirement to make this work.
|
||||
state.msg_id = self._state.write_data_as_message(
|
||||
buffer, state.data, isinstance(state.request, TLRequest),
|
||||
after_id=after_id
|
||||
)
|
||||
__log__.debug('Assigned msg_id = %d to %s (%x)',
|
||||
state.msg_id, state.request.__class__.__name__,
|
||||
id(state.request))
|
||||
|
||||
# TODO Yield in the inner loop -> Telegram "Invalid container". Why?
|
||||
for state in state_list:
|
||||
if not isinstance(state, list):
|
||||
yield write_state(state)
|
||||
else:
|
||||
after_id = None
|
||||
for s in state:
|
||||
yield write_state(s, after_id)
|
||||
after_id = s.msg_id
|
||||
|
||||
yield write_state(None)
|
||||
|
||||
def __str__(self):
|
||||
return str(self._connection)
|
|
@ -30,7 +30,7 @@ class MTProtoPlainSender:
|
|||
body = bytes(request)
|
||||
msg_id = self._state._get_new_msg_id()
|
||||
await self._connection.send(
|
||||
struct.pack('<QQi', 0, msg_id, len(body)) + body
|
||||
struct.pack('<qqi', 0, msg_id, len(body)) + body
|
||||
)
|
||||
|
||||
body = await self._connection.recv()
|
||||
|
|
|
@ -1,13 +1,19 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
|
||||
from . import MTProtoPlainSender, authenticator
|
||||
from . import authenticator
|
||||
from .mtprotolayer import MTProtoLayer
|
||||
from .mtprotoplainsender import MTProtoPlainSender
|
||||
from .requeststate import RequestState
|
||||
from ..tl.tlobject import TLRequest
|
||||
from .. import utils
|
||||
from ..errors import (
|
||||
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
|
||||
rpc_message_to_error
|
||||
BadMessageError, BrokenAuthKeyError, SecurityError, TypeNotFoundError,
|
||||
InvalidChecksumError, rpc_message_to_error
|
||||
)
|
||||
from ..extensions import BinaryReader
|
||||
from ..helpers import _ReadyQueue
|
||||
from ..tl.core import RpcResult, MessageContainer, GzipPacked
|
||||
from ..tl.functions.auth import LogOutRequest
|
||||
from ..tl.types import (
|
||||
|
@ -20,12 +26,6 @@ from ..utils import AsyncClassWrapper
|
|||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Place this object in the send queue when a reconnection is needed
|
||||
# so there is an item to read and we can early quit the loop, since
|
||||
# without this it will block until there's something in the queue.
|
||||
_reconnect_sentinel = object()
|
||||
|
||||
|
||||
class MTProtoSender:
|
||||
"""
|
||||
MTProto Mobile Protocol sender
|
||||
|
@ -41,16 +41,15 @@ class MTProtoSender:
|
|||
A new authorization key will be generated on connection if no other
|
||||
key exists yet.
|
||||
"""
|
||||
def __init__(self, state, connection, loop, *,
|
||||
retries=5, auto_reconnect=True, update_callback=None,
|
||||
def __init__(self, loop, *,
|
||||
retries=5, auto_reconnect=True, connect_timeout=None,
|
||||
update_callback=None,
|
||||
auth_key_callback=None, auto_reconnect_callback=None):
|
||||
self.state = state
|
||||
self._connection = connection
|
||||
self._connection = None # MTProtoLayer, a.k.a. encrypted connection
|
||||
self._loop = loop
|
||||
self._ip = None
|
||||
self._port = None
|
||||
self._retries = retries
|
||||
self._auto_reconnect = auto_reconnect
|
||||
self._connect_timeout = connect_timeout
|
||||
self._update_callback = update_callback
|
||||
self._auth_key_callback = auth_key_callback
|
||||
self._auto_reconnect_callback = auto_reconnect_callback
|
||||
|
@ -69,27 +68,21 @@ class MTProtoSender:
|
|||
self._send_loop_handle = None
|
||||
self._recv_loop_handle = None
|
||||
|
||||
# Sending something shouldn't block
|
||||
self._send_queue = _ContainerQueue()
|
||||
# Outgoing messages are put in a queue and sent in a batch.
|
||||
# Note that here we're also storing their ``_RequestState``.
|
||||
# Note that it may also store lists (implying order must be kept).
|
||||
self._send_queue = _ReadyQueue(self._loop)
|
||||
|
||||
# Telegram responds to messages out of order. Keep
|
||||
# {id: Message} to set their Future result upon arrival.
|
||||
self._pending_messages = {}
|
||||
# Sent states are remembered until a response is received.
|
||||
self._pending_state = {}
|
||||
|
||||
# Containers are accepted or rejected as a whole when any of
|
||||
# its inner requests are acknowledged. For this purpose we
|
||||
# all the sent containers here.
|
||||
self._pending_containers = []
|
||||
|
||||
# We need to acknowledge every response from Telegram
|
||||
# Responses must be acknowledged, and we can also batch these.
|
||||
self._pending_ack = set()
|
||||
|
||||
# Similar to pending_messages but only for the last ack.
|
||||
# Ack can't be put in the messages because Telegram never
|
||||
# responds to acknowledges (they're just that, acknowledges),
|
||||
# so it would grow to infinite otherwise, but on bad salt it's
|
||||
# necessary to resend them just like everything else.
|
||||
self._last_ack = None
|
||||
# Similar to pending_messages but only for the last acknowledges.
|
||||
# These can't go in pending_messages because no acknowledge for them
|
||||
# is received, but we may still need to resend their state on bad salts.
|
||||
self._last_acks = collections.deque(maxlen=10)
|
||||
|
||||
# Jump table from response ID to method that handles it
|
||||
self._handlers = {
|
||||
|
@ -111,18 +104,15 @@ class MTProtoSender:
|
|||
|
||||
# Public API
|
||||
|
||||
async def connect(self, ip, port):
|
||||
async def connect(self, auth_key, connection):
|
||||
"""
|
||||
Connects to the specified ``ip:port``, and generates a new
|
||||
authorization key for the `MTProtoSender.session` if it does
|
||||
not exist yet.
|
||||
Connects to the specified given connection using the given auth key.
|
||||
"""
|
||||
if self._user_connected:
|
||||
__log__.info('User is already connected!')
|
||||
return
|
||||
|
||||
self._ip = ip
|
||||
self._port = port
|
||||
self._connection = MTProtoLayer(connection, auth_key)
|
||||
self._user_connected = True
|
||||
await self._connect()
|
||||
|
||||
|
@ -134,53 +124,13 @@ class MTProtoSender:
|
|||
Cleanly disconnects the instance from the network, cancels
|
||||
all pending requests, and closes the send and receive loops.
|
||||
"""
|
||||
if not self._user_connected:
|
||||
__log__.info('User is already disconnected!')
|
||||
return
|
||||
|
||||
await self._disconnect()
|
||||
|
||||
async def _disconnect(self, error=None):
|
||||
__log__.info('Disconnecting from {}...'.format(self._ip))
|
||||
self._user_connected = False
|
||||
try:
|
||||
__log__.debug('Closing current connection...')
|
||||
await self._connection.close()
|
||||
finally:
|
||||
__log__.debug('Cancelling {} pending message(s)...'
|
||||
.format(len(self._pending_messages)))
|
||||
for message in self._pending_messages.values():
|
||||
if error and not message.future.done():
|
||||
message.future.set_exception(error)
|
||||
else:
|
||||
message.future.cancel()
|
||||
|
||||
self._pending_messages.clear()
|
||||
self._pending_ack.clear()
|
||||
self._last_ack = None
|
||||
|
||||
if self._send_loop_handle:
|
||||
__log__.debug('Cancelling the send loop...')
|
||||
self._send_loop_handle.cancel()
|
||||
|
||||
if self._recv_loop_handle:
|
||||
__log__.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
|
||||
__log__.info('Disconnection from {} complete!'.format(self._ip))
|
||||
if self._disconnected and not self._disconnected.done():
|
||||
if error:
|
||||
self._disconnected.set_exception(error)
|
||||
else:
|
||||
self._disconnected.set_result(None)
|
||||
self._disconnect()
|
||||
|
||||
def send(self, request, ordered=False):
|
||||
"""
|
||||
This method enqueues the given request to be sent.
|
||||
|
||||
The request will be wrapped inside a `TLMessage` until its
|
||||
response arrives, and the `Future` response of the `TLMessage`
|
||||
is immediately returned so that one can further ``await`` it:
|
||||
This method enqueues the given request to be sent. Its send
|
||||
state will be saved until a response arrives, and a ``Future``
|
||||
that will be resolved when the response arrives will be returned:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -202,23 +152,23 @@ class MTProtoSender:
|
|||
if not self._user_connected:
|
||||
raise ConnectionError('Cannot send requests while disconnected')
|
||||
|
||||
if utils.is_list_like(request):
|
||||
result = []
|
||||
after = None
|
||||
for r in request:
|
||||
message = self.state.create_message(
|
||||
r, loop=self._loop, after=after)
|
||||
|
||||
self._pending_messages[message.msg_id] = message
|
||||
self._send_queue.put_nowait(message)
|
||||
result.append(message.future)
|
||||
after = ordered and message
|
||||
return result
|
||||
if not utils.is_list_like(request):
|
||||
state = RequestState(request, self._loop)
|
||||
self._send_queue.append(state)
|
||||
return state.future
|
||||
else:
|
||||
message = self.state.create_message(request, loop=self._loop)
|
||||
self._pending_messages[message.msg_id] = message
|
||||
self._send_queue.put_nowait(message)
|
||||
return message.future
|
||||
states = []
|
||||
futures = []
|
||||
for req in request:
|
||||
state = RequestState(req, self._loop)
|
||||
states.append(state)
|
||||
futures.append(state.future)
|
||||
if ordered:
|
||||
self._send_queue.append(states)
|
||||
else:
|
||||
self._send_queue.extend(states)
|
||||
|
||||
return futures
|
||||
|
||||
@property
|
||||
def disconnected(self):
|
||||
|
@ -239,12 +189,12 @@ class MTProtoSender:
|
|||
authorization key if necessary, and starting the send and
|
||||
receive loops.
|
||||
"""
|
||||
__log__.info('Connecting to {}:{}...'.format(self._ip, self._port))
|
||||
__log__.info('Connecting to %s...', self._connection)
|
||||
for retry in range(1, self._retries + 1):
|
||||
try:
|
||||
__log__.debug('Connection attempt {}...'.format(retry))
|
||||
await self._connection.connect(self._ip, self._port)
|
||||
except (asyncio.TimeoutError, OSError) as e:
|
||||
await self._connection.connect(timeout=self._connect_timeout)
|
||||
except (OSError, asyncio.TimeoutError) as e:
|
||||
__log__.warning('Attempt {} at connecting failed: {}: {}'
|
||||
.format(retry, type(e).__name__, e))
|
||||
else:
|
||||
|
@ -254,16 +204,17 @@ class MTProtoSender:
|
|||
.format(self._retries))
|
||||
|
||||
__log__.debug('Connection success!')
|
||||
if self.state.auth_key is None:
|
||||
plain = MTProtoPlainSender(self._connection)
|
||||
state = self._connection._state
|
||||
if state.auth_key is None:
|
||||
plain = MTProtoPlainSender(self._connection._connection)
|
||||
for retry in range(1, self._retries + 1):
|
||||
try:
|
||||
__log__.debug('New auth_key attempt {}...'.format(retry))
|
||||
self.state.auth_key, self.state.time_offset =\
|
||||
state.auth_key, state.time_offset =\
|
||||
await authenticator.do_authentication(plain)
|
||||
|
||||
if self._auth_key_callback:
|
||||
await self._auth_key_callback(self.state.auth_key)
|
||||
await self._auth_key_callback(state.auth_key)
|
||||
|
||||
break
|
||||
except (SecurityError, AssertionError) as e:
|
||||
|
@ -284,14 +235,50 @@ class MTProtoSender:
|
|||
# First connection or manual reconnection after a failure
|
||||
if self._disconnected is None or self._disconnected.done():
|
||||
self._disconnected = self._loop.create_future()
|
||||
__log__.info('Connection to {} complete!'.format(self._ip))
|
||||
__log__.info('Connection to %s complete!', self._connection)
|
||||
|
||||
def _disconnect(self, error=None):
|
||||
__log__.info('Disconnecting from %s...', self._connection)
|
||||
self._user_connected = False
|
||||
try:
|
||||
__log__.debug('Closing current connection...')
|
||||
self._connection.disconnect()
|
||||
finally:
|
||||
__log__.debug('Cancelling {} pending message(s)...'
|
||||
.format(len(self._pending_state)))
|
||||
for state in self._pending_state.values():
|
||||
if error and not state.future.done():
|
||||
state.future.set_exception(error)
|
||||
else:
|
||||
state.future.cancel()
|
||||
|
||||
self._pending_state.clear()
|
||||
self._pending_ack.clear()
|
||||
self._last_ack = None
|
||||
|
||||
if self._send_loop_handle:
|
||||
__log__.debug('Cancelling the send loop...')
|
||||
self._send_loop_handle.cancel()
|
||||
|
||||
if self._recv_loop_handle:
|
||||
__log__.debug('Cancelling the receive loop...')
|
||||
self._recv_loop_handle.cancel()
|
||||
|
||||
__log__.info('Disconnection from %s complete!', self._connection)
|
||||
if self._disconnected and not self._disconnected.done():
|
||||
if error:
|
||||
self._disconnected.set_exception(error)
|
||||
else:
|
||||
self._disconnected.set_result(None)
|
||||
|
||||
async def _reconnect(self):
|
||||
"""
|
||||
Cleanly disconnects and then reconnects.
|
||||
"""
|
||||
self._reconnecting = True
|
||||
self._send_queue.put_nowait(_reconnect_sentinel)
|
||||
|
||||
__log__.debug('Closing current connection...')
|
||||
self._connection.disconnect()
|
||||
|
||||
__log__.debug('Awaiting for the send loop before reconnecting...')
|
||||
await self._send_loop_handle
|
||||
|
@ -300,23 +287,27 @@ class MTProtoSender:
|
|||
await self._recv_loop_handle
|
||||
|
||||
__log__.debug('Closing current connection...')
|
||||
await self._connection.close()
|
||||
self._connection.disconnect()
|
||||
|
||||
self._reconnecting = False
|
||||
|
||||
# Start with a clean state (and thus session ID) to avoid old msgs
|
||||
self._connection.reset_state()
|
||||
|
||||
retries = self._retries if self._auto_reconnect else 0
|
||||
for retry in range(1, retries + 1):
|
||||
try:
|
||||
await self._connect()
|
||||
for m in self._pending_messages.values():
|
||||
self._send_queue.put_nowait(m)
|
||||
except ConnectionError:
|
||||
__log__.info('Failed reconnection retry %d/%d', retry, retries)
|
||||
else:
|
||||
self._send_queue.extend(self._pending_state.values())
|
||||
self._pending_state.clear()
|
||||
|
||||
if self._auto_reconnect_callback:
|
||||
self._loop.create_task(self._auto_reconnect_callback())
|
||||
|
||||
break
|
||||
except ConnectionError:
|
||||
__log__.info('Failed reconnection retry %d/%d', retry, retries)
|
||||
else:
|
||||
__log__.error('Failed to reconnect automatically.')
|
||||
await self._disconnect(error=ConnectionError())
|
||||
|
@ -326,23 +317,6 @@ class MTProtoSender:
|
|||
if self._user_connected:
|
||||
self._loop.create_task(self._reconnect())
|
||||
|
||||
def _clean_containers(self, msg_ids):
|
||||
"""
|
||||
Helper method to clean containers from the pending messages
|
||||
once a wrapped msg_id of them has been acknowledged.
|
||||
|
||||
This is the only way we can resend TLMessage(MessageContainer)
|
||||
on bad notifications and also mark them as received once any
|
||||
of their inner TLMessage is acknowledged.
|
||||
"""
|
||||
for i in reversed(range(len(self._pending_containers))):
|
||||
message = self._pending_containers[i]
|
||||
for msg in message.obj.messages:
|
||||
if msg.msg_id in msg_ids:
|
||||
del self._pending_containers[i]
|
||||
del self._pending_messages[message.msg_id]
|
||||
break
|
||||
|
||||
# Loops
|
||||
|
||||
async def _send_loop(self):
|
||||
|
@ -354,67 +328,31 @@ class MTProtoSender:
|
|||
"""
|
||||
while self._user_connected and not self._reconnecting:
|
||||
if self._pending_ack:
|
||||
self._last_ack = self.state.create_message(
|
||||
MsgsAck(list(self._pending_ack)), loop=self._loop
|
||||
)
|
||||
self._send_queue.put_nowait(self._last_ack)
|
||||
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
|
||||
self._send_queue.append(ack)
|
||||
self._last_acks.append(ack)
|
||||
self._pending_ack.clear()
|
||||
|
||||
messages = await self._send_queue.get()
|
||||
if messages == _reconnect_sentinel:
|
||||
if self._reconnecting:
|
||||
break
|
||||
state_list = await self._send_queue.get(
|
||||
self._connection._connection.disconnected)
|
||||
|
||||
if state_list is None:
|
||||
break
|
||||
|
||||
try:
|
||||
await self._connection.send(state_list)
|
||||
except Exception:
|
||||
__log__.exception('Unhandled error while sending data')
|
||||
continue
|
||||
|
||||
for state in state_list:
|
||||
if not isinstance(state, list):
|
||||
if isinstance(state.request, TLRequest):
|
||||
self._pending_state[state.msg_id] = state
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(messages, list):
|
||||
message = self.state.create_message(
|
||||
MessageContainer(messages), loop=self._loop)
|
||||
|
||||
self._pending_messages[message.msg_id] = message
|
||||
self._pending_containers.append(message)
|
||||
else:
|
||||
message = messages
|
||||
messages = [message]
|
||||
|
||||
__log__.debug(
|
||||
'Packing %d outgoing message(s) %s...', len(messages),
|
||||
', '.join(x.obj.__class__.__name__ for x in messages)
|
||||
)
|
||||
body = self.state.pack_message(message)
|
||||
|
||||
while not any(m.future.cancelled() for m in messages):
|
||||
try:
|
||||
__log__.debug('Sending {} bytes...'.format(len(body)))
|
||||
await self._connection.send(body)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
if isinstance(e, ConnectionError):
|
||||
__log__.info('Connection reset while sending %s', e)
|
||||
elif isinstance(e, OSError):
|
||||
__log__.warning('OSError while sending %s', e)
|
||||
else:
|
||||
__log__.exception('Unhandled exception while receiving')
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
|
||||
self._start_reconnect()
|
||||
break
|
||||
else:
|
||||
# Remove the cancelled messages from pending
|
||||
__log__.info('Some futures were cancelled, aborted send')
|
||||
self._clean_containers([m.msg_id for m in messages])
|
||||
for m in messages:
|
||||
if m.future.cancelled():
|
||||
self._pending_messages.pop(m.msg_id, None)
|
||||
else:
|
||||
self._send_queue.put_nowait(m)
|
||||
|
||||
__log__.debug('Outgoing messages {} sent!'
|
||||
.format(', '.join(str(m.msg_id) for m in messages)))
|
||||
for s in state:
|
||||
if isinstance(s.request, TLRequest):
|
||||
self._pending_state[s.msg_id] = s
|
||||
|
||||
async def _recv_loop(self):
|
||||
"""
|
||||
|
@ -424,68 +362,44 @@ class MTProtoSender:
|
|||
Besides `connect`, only this method ever receives data.
|
||||
"""
|
||||
while self._user_connected and not self._reconnecting:
|
||||
__log__.debug('Receiving items from the network...')
|
||||
try:
|
||||
__log__.debug('Receiving items from the network...')
|
||||
body = await self._connection.recv()
|
||||
except asyncio.TimeoutError:
|
||||
message = await self._connection.recv()
|
||||
except TypeNotFoundError as e:
|
||||
__log__.info('Type %08x not found, remaining data %r',
|
||||
e.invalid_constructor_id, e.remaining)
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
if isinstance(e, ConnectionError):
|
||||
__log__.info('Connection reset while receiving %s', e)
|
||||
elif isinstance(e, OSError):
|
||||
__log__.warning('OSError while receiving %s', e)
|
||||
else:
|
||||
__log__.exception('Unhandled exception while receiving')
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
|
||||
self._start_reconnect()
|
||||
break
|
||||
|
||||
# TODO Check salt, session_id and sequence_number
|
||||
__log__.debug('Decoding packet of %d bytes...', len(body))
|
||||
try:
|
||||
message = self.state.unpack_message(body)
|
||||
except (BrokenAuthKeyError, BufferError) as e:
|
||||
# The authorization key may be broken if a message was
|
||||
# sent malformed, or if the authkey truly is corrupted.
|
||||
#
|
||||
# There may be a buffer error if Telegram's response was too
|
||||
# short and hence not understood. Reset the authorization key
|
||||
# and try again in either case.
|
||||
#
|
||||
# TODO Is it possible to detect malformed messages vs
|
||||
# an actually broken authkey?
|
||||
__log__.warning('Broken authorization key?: {}'.format(e))
|
||||
self.state.auth_key = None
|
||||
self._start_reconnect()
|
||||
break
|
||||
except SecurityError as e:
|
||||
# A step while decoding had the incorrect data. This message
|
||||
# should not be considered safe and it should be ignored.
|
||||
__log__.warning('Security error while unpacking a '
|
||||
'received message: {}'.format(e))
|
||||
continue
|
||||
except TypeNotFoundError as e:
|
||||
# The payload inside the message was not a known TLObject.
|
||||
__log__.info('Server replied with an unknown type {:08x}: {!r}'
|
||||
.format(e.invalid_constructor_id, e.remaining))
|
||||
'received message: %s', e)
|
||||
continue
|
||||
except InvalidChecksumError as e:
|
||||
__log__.warning(
|
||||
'Invalid checksum on the read packet (was %s expected %s)',
|
||||
e.checksum, e.valid_checksum
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
__log__.exception('Unhandled exception while unpacking %s',e)
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
except (BrokenAuthKeyError, BufferError):
|
||||
__log__.info('Broken authorization key; resetting')
|
||||
self._connection._state.auth_key = None
|
||||
self._start_reconnect()
|
||||
return
|
||||
except asyncio.IncompleteReadError:
|
||||
__log__.info('Telegram closed the connection')
|
||||
self._start_reconnect()
|
||||
return
|
||||
except Exception:
|
||||
__log__.exception('Unhandled error while receiving data')
|
||||
self._start_reconnect()
|
||||
return
|
||||
else:
|
||||
try:
|
||||
await self._process_message(message)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
__log__.exception('Unhandled exception while '
|
||||
'processing %s', message)
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
except Exception:
|
||||
__log__.exception('Unhandled error while processing msgs')
|
||||
|
||||
# Response Handlers
|
||||
|
||||
|
@ -500,6 +414,30 @@ class MTProtoSender:
|
|||
self._handle_update)
|
||||
await handler(message)
|
||||
|
||||
def _pop_states(self, msg_id):
|
||||
"""
|
||||
Pops the states known to match the given ID from pending messages.
|
||||
|
||||
This method should be used when the response isn't specific.
|
||||
"""
|
||||
state = self._pending_state.pop(msg_id, None)
|
||||
if state:
|
||||
return [state]
|
||||
|
||||
to_pop = []
|
||||
for state in self._pending_state.values():
|
||||
if state.container_id == msg_id:
|
||||
to_pop.append(state.msg_id)
|
||||
|
||||
if to_pop:
|
||||
return [self._pending_state.pop(x) for x in to_pop]
|
||||
|
||||
for ack in self._last_acks:
|
||||
if ack.msg_id == msg_id:
|
||||
return [ack]
|
||||
|
||||
return []
|
||||
|
||||
async def _handle_rpc_result(self, message):
|
||||
"""
|
||||
Handles the result for Remote Procedure Calls:
|
||||
|
@ -509,11 +447,11 @@ class MTProtoSender:
|
|||
This is where the future results for sent requests are set.
|
||||
"""
|
||||
rpc_result = message.obj
|
||||
message = self._pending_messages.pop(rpc_result.req_msg_id, None)
|
||||
state = self._pending_state.pop(rpc_result.req_msg_id, None)
|
||||
__log__.debug('Handling RPC result for message %d',
|
||||
rpc_result.req_msg_id)
|
||||
|
||||
if not message:
|
||||
if not state:
|
||||
# TODO We should not get responses to things we never sent
|
||||
# However receiving a File() with empty bytes is "common".
|
||||
# See #658, #759 and #958. They seem to happen in a container
|
||||
|
@ -529,22 +467,17 @@ class MTProtoSender:
|
|||
|
||||
if rpc_result.error:
|
||||
error = rpc_message_to_error(rpc_result.error)
|
||||
self._send_queue.put_nowait(self.state.create_message(
|
||||
MsgsAck([message.msg_id]), loop=self._loop
|
||||
))
|
||||
self._send_queue.append(
|
||||
RequestState(MsgsAck([state.msg_id]), loop=self._loop))
|
||||
|
||||
if not message.future.cancelled():
|
||||
message.future.set_exception(error)
|
||||
if not state.future.cancelled():
|
||||
state.future.set_exception(error)
|
||||
else:
|
||||
# TODO Would be nice to avoid accessing a per-obj read_result
|
||||
# Instead have a variable that indicated how the result should
|
||||
# be read (an enum) and dispatch to read the result, mostly
|
||||
# always it's just a normal TLObject.
|
||||
with BinaryReader(rpc_result.body) as reader:
|
||||
result = message.obj.read_result(reader)
|
||||
result = state.request.read_result(reader)
|
||||
|
||||
if not message.future.cancelled():
|
||||
message.future.set_result(result)
|
||||
if not state.future.cancelled():
|
||||
state.future.set_result(result)
|
||||
|
||||
async def _handle_container(self, message):
|
||||
"""
|
||||
|
@ -582,9 +515,9 @@ class MTProtoSender:
|
|||
"""
|
||||
pong = message.obj
|
||||
__log__.debug('Handling pong for message %d', pong.msg_id)
|
||||
message = self._pending_messages.pop(pong.msg_id, None)
|
||||
if message:
|
||||
message.future.set_result(pong)
|
||||
state = self._pending_state.pop(pong.msg_id, None)
|
||||
if state:
|
||||
state.future.set_result(pong)
|
||||
|
||||
async def _handle_bad_server_salt(self, message):
|
||||
"""
|
||||
|
@ -596,18 +529,11 @@ class MTProtoSender:
|
|||
"""
|
||||
bad_salt = message.obj
|
||||
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
|
||||
self.state.salt = bad_salt.new_server_salt
|
||||
if self._last_ack and bad_salt.bad_msg_id == self._last_ack.msg_id:
|
||||
self._send_queue.put_nowait(self._last_ack)
|
||||
return
|
||||
self._connection._state.salt = bad_salt.new_server_salt
|
||||
states = self._pop_states(bad_salt.bad_msg_id)
|
||||
self._send_queue.extend(states)
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait(
|
||||
self._pending_messages[bad_salt.bad_msg_id])
|
||||
except KeyError:
|
||||
# May be MsgsAck, those are not saved in pending messages
|
||||
__log__.info('Message %d not resent due to bad salt',
|
||||
bad_salt.bad_msg_id)
|
||||
__log__.debug('%d message(s) will be resent', len(states))
|
||||
|
||||
async def _handle_bad_notification(self, message):
|
||||
"""
|
||||
|
@ -618,44 +544,30 @@ class MTProtoSender:
|
|||
error_code:int = BadMsgNotification;
|
||||
"""
|
||||
bad_msg = message.obj
|
||||
msg = self._pending_messages.get(bad_msg.bad_msg_id)
|
||||
states = self._pop_states(bad_msg.bad_msg_id)
|
||||
|
||||
__log__.debug('Handling bad msg %s', bad_msg)
|
||||
if bad_msg.error_code in (16, 17):
|
||||
# Sent msg_id too low or too high (respectively).
|
||||
# Use the current msg_id to determine the right time offset.
|
||||
to = self.state.update_time_offset(correct_msg_id=message.msg_id)
|
||||
to = self._connection._state.update_time_offset(
|
||||
correct_msg_id=message.msg_id)
|
||||
__log__.info('System clock is wrong, set time offset to %ds', to)
|
||||
|
||||
# Correct the msg_id *of the message to resend*, not all.
|
||||
#
|
||||
# If we correct them all, new "bad message" would not find
|
||||
# the old invalid IDs, causing all awaits to never finish.
|
||||
if msg:
|
||||
del self._pending_messages[msg.msg_id]
|
||||
self.state.update_message_id(msg)
|
||||
self._pending_messages[msg.msg_id] = msg
|
||||
|
||||
elif bad_msg.error_code == 32:
|
||||
# msg_seqno too low, so just pump it up by some "large" amount
|
||||
# TODO A better fix would be to start with a new fresh session ID
|
||||
self.state._sequence += 64
|
||||
self._connection._state._sequence += 64
|
||||
elif bad_msg.error_code == 33:
|
||||
# msg_seqno too high never seems to happen but just in case
|
||||
self.state._sequence -= 16
|
||||
self._connection._state._sequence -= 16
|
||||
else:
|
||||
if msg:
|
||||
del self._pending_messages[msg.msg_id]
|
||||
msg.future.set_exception(BadMessageError(bad_msg.error_code))
|
||||
for state in states:
|
||||
state.future.set_exception(BadMessageError(bad_msg.error_code))
|
||||
return
|
||||
|
||||
# Messages are to be re-sent once we've corrected the issue
|
||||
if msg:
|
||||
self._send_queue.put_nowait(msg)
|
||||
else:
|
||||
# May be MsgsAck, those are not saved in pending messages
|
||||
__log__.info('Message %d not resent due to bad msg',
|
||||
bad_msg.bad_msg_id)
|
||||
self._send_queue.extend(states)
|
||||
__log__.debug('%d messages will be resent due to bad msg', len(states))
|
||||
|
||||
async def _handle_detailed_info(self, message):
|
||||
"""
|
||||
|
@ -690,7 +602,7 @@ class MTProtoSender:
|
|||
"""
|
||||
# TODO https://goo.gl/LMyN7A
|
||||
__log__.debug('Handling new session created')
|
||||
self.state.salt = message.obj.server_salt
|
||||
self._connection._state.salt = message.obj.server_salt
|
||||
|
||||
async def _handle_ack(self, message):
|
||||
"""
|
||||
|
@ -709,14 +621,11 @@ class MTProtoSender:
|
|||
"""
|
||||
ack = message.obj
|
||||
__log__.debug('Handling acknowledge for %s', str(ack.msg_ids))
|
||||
if self._pending_containers:
|
||||
self._clean_containers(ack.msg_ids)
|
||||
|
||||
for msg_id in ack.msg_ids:
|
||||
msg = self._pending_messages.get(msg_id, None)
|
||||
if msg and isinstance(msg.obj, LogOutRequest):
|
||||
del self._pending_messages[msg_id]
|
||||
msg.future.set_result(True)
|
||||
state = self._pending_state.get(msg_id)
|
||||
if state and isinstance(state.request, LogOutRequest):
|
||||
del self._pending_state[msg_id]
|
||||
state.future.set_result(True)
|
||||
|
||||
async def _handle_future_salts(self, message):
|
||||
"""
|
||||
|
@ -729,52 +638,20 @@ class MTProtoSender:
|
|||
# TODO save these salts and automatically adjust to the
|
||||
# correct one whenever the salt in use expires.
|
||||
__log__.debug('Handling future salts for message %d', message.msg_id)
|
||||
msg = self._pending_messages.pop(message.msg_id, None)
|
||||
if msg:
|
||||
msg.future.set_result(message.obj)
|
||||
state = self._pending_state.pop(message.msg_id, None)
|
||||
if state:
|
||||
state.future.set_result(message.obj)
|
||||
|
||||
async def _handle_state_forgotten(self, message):
|
||||
"""
|
||||
Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
|
||||
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
|
||||
"""
|
||||
self.send(MsgsStateInfo(req_msg_id=message.msg_id,
|
||||
info=chr(1) * len(message.obj.msg_ids)))
|
||||
self._send_queue.append(RequestState(MsgsStateInfo(
|
||||
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
|
||||
loop=self._loop))
|
||||
|
||||
async def _handle_msg_all(self, message):
|
||||
"""
|
||||
Handles :tl:`MsgsAllInfo` by doing nothing (yet).
|
||||
"""
|
||||
|
||||
|
||||
class _ContainerQueue(asyncio.Queue):
|
||||
"""
|
||||
An asyncio queue that's aware of `MessageContainer` instances.
|
||||
|
||||
The `get` method returns either a single `TLMessage` or a list
|
||||
of them that should be turned into a new `MessageContainer`.
|
||||
|
||||
Instances of this class can be replaced with the simpler
|
||||
``asyncio.Queue`` when needed for testing purposes, and
|
||||
a list won't be returned in said case.
|
||||
"""
|
||||
async def get(self):
|
||||
result = await super().get()
|
||||
if self.empty() or result == _reconnect_sentinel or\
|
||||
isinstance(result.obj, MessageContainer):
|
||||
return result
|
||||
|
||||
size = result.size()
|
||||
result = [result]
|
||||
while not self.empty():
|
||||
item = self.get_nowait()
|
||||
if (item == _reconnect_sentinel or
|
||||
isinstance(item.obj, MessageContainer)
|
||||
or size + item.size() > MessageContainer.MAXIMUM_SIZE):
|
||||
self.put_nowait(item)
|
||||
break
|
||||
else:
|
||||
size += item.size()
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
|
|
|
@ -8,7 +8,8 @@ from ..crypto import AES
|
|||
from ..errors import SecurityError, BrokenAuthKeyError
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl.core import TLMessage
|
||||
from ..tl.tlobject import TLRequest
|
||||
from ..tl.functions import InvokeAfterMsgRequest
|
||||
from ..tl.core.gzippacked import GzipPacked
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
@ -27,6 +28,14 @@ class MTProtoState:
|
|||
for all these is not a good idea as each need their own authkey, and
|
||||
the concept of "copying" sessions with the unnecessary entities or
|
||||
updates state for these connections doesn't make sense.
|
||||
|
||||
While it would be possible to have a `MTProtoPlainState` that does no
|
||||
encryption so that it was usable through the `MTProtoLayer` and thus
|
||||
avoid the need for a `MTProtoPlainSender`, the `MTProtoLayer` is more
|
||||
focused to efficiency and this state is also more advanced (since it
|
||||
supports gzipping and invoking after other message IDs). There are too
|
||||
many methods that would be needed to make it convenient to use for the
|
||||
authentication process, at which point the `MTProtoPlainSender` is better.
|
||||
"""
|
||||
def __init__(self, auth_key):
|
||||
# Session IDs can be random on every connection
|
||||
|
@ -37,20 +46,6 @@ class MTProtoState:
|
|||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
|
||||
def create_message(self, obj, *, loop, after=None):
|
||||
"""
|
||||
Creates a new `telethon.tl.tl_message.TLMessage` from
|
||||
the given `telethon.tl.tlobject.TLObject` instance.
|
||||
"""
|
||||
return TLMessage(
|
||||
msg_id=self._get_new_msg_id(),
|
||||
seq_no=self._get_seq_no(isinstance(obj, TLRequest)),
|
||||
obj=obj,
|
||||
after_id=after.msg_id if after else None,
|
||||
out=True, # Pre-convert the request into bytes
|
||||
loop=loop
|
||||
)
|
||||
|
||||
def update_message_id(self, message):
|
||||
"""
|
||||
Updates the message ID to a new one,
|
||||
|
@ -74,14 +69,31 @@ class MTProtoState:
|
|||
|
||||
return aes_key, aes_iv
|
||||
|
||||
def pack_message(self, message):
|
||||
def write_data_as_message(self, buffer, data, content_related,
|
||||
*, after_id=None):
|
||||
"""
|
||||
Packs the given `telethon.tl.tl_message.TLMessage` using the
|
||||
current authorization key following MTProto 2.0 guidelines.
|
||||
Writes a message containing the given data into buffer.
|
||||
|
||||
See https://core.telegram.org/mtproto/description.
|
||||
Returns the message id.
|
||||
"""
|
||||
data = struct.pack('<qq', self.salt, self.id) + bytes(message)
|
||||
msg_id = self._get_new_msg_id()
|
||||
seq_no = self._get_seq_no(content_related)
|
||||
if after_id is None:
|
||||
body = GzipPacked.gzip_if_smaller(content_related, data)
|
||||
else:
|
||||
body = GzipPacked.gzip_if_smaller(content_related,
|
||||
bytes(InvokeAfterMsgRequest(after_id, data)))
|
||||
|
||||
buffer.write(struct.pack('<qii', msg_id, seq_no, len(body)))
|
||||
buffer.write(body)
|
||||
return msg_id
|
||||
|
||||
def encrypt_message_data(self, data):
|
||||
"""
|
||||
Encrypts the given message data using the current authorization key
|
||||
following MTProto 2.0 guidelines core.telegram.org/mtproto/description.
|
||||
"""
|
||||
data = struct.pack('<qq', self.salt, self.id) + data
|
||||
padding = os.urandom(-(len(data) + 12) % 16 + 12)
|
||||
|
||||
# Being substr(what, offset, length); x = 0 for client
|
||||
|
@ -97,16 +109,18 @@ class MTProtoState:
|
|||
return (key_id + msg_key +
|
||||
AES.encrypt_ige(data + padding, aes_key, aes_iv))
|
||||
|
||||
def unpack_message(self, body):
|
||||
def decrypt_message_data(self, body):
|
||||
"""
|
||||
Inverse of `pack_message` for incoming server messages.
|
||||
Inverse of `encrypt_message_data` for incoming server messages.
|
||||
"""
|
||||
if len(body) < 8:
|
||||
# TODO If len == 4, raise HTTPErrorCode(-little endian int)
|
||||
if body == b'l\xfe\xff\xff':
|
||||
raise BrokenAuthKeyError()
|
||||
else:
|
||||
raise BufferError("Can't decode packet ({})".format(body))
|
||||
|
||||
# TODO Check salt, session_id and sequence_number
|
||||
key_id = struct.unpack('<Q', body[:8])[0]
|
||||
if key_id != self.auth_key.key_id:
|
||||
raise SecurityError('Server replied with an invalid auth key')
|
||||
|
@ -136,7 +150,7 @@ class MTProtoState:
|
|||
# reader isn't used for anything else after this, it's unnecessary.
|
||||
obj = reader.tgread_object()
|
||||
|
||||
return TLMessage(remote_msg_id, remote_sequence, obj, loop=None)
|
||||
return TLMessage(remote_msg_id, remote_sequence, obj)
|
||||
|
||||
def _get_new_msg_id(self):
|
||||
"""
|
||||
|
|
18
telethon/network/requeststate.py
Normal file
18
telethon/network/requeststate.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import asyncio
|
||||
|
||||
|
||||
class RequestState:
|
||||
"""
|
||||
This request state holds several information relevant to sent messages,
|
||||
in particular the message ID assigned to the request, the container ID
|
||||
it belongs to, the request itself, the request as bytes, and the future
|
||||
result that will eventually be resolved.
|
||||
"""
|
||||
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future')
|
||||
|
||||
def __init__(self, request, loop):
|
||||
self.container_id = None
|
||||
self.msg_id = None
|
||||
self.request = request
|
||||
self.data = bytes(request)
|
||||
self.future = asyncio.Future(loop=loop)
|
|
@ -11,15 +11,14 @@ class GzipPacked(TLObject):
|
|||
self.data = data
|
||||
|
||||
@staticmethod
|
||||
def gzip_if_smaller(request):
|
||||
def gzip_if_smaller(content_related, data):
|
||||
"""Calls bytes(request), and based on a certain threshold,
|
||||
optionally gzips the resulting data. If the gzipped data is
|
||||
smaller than the original byte array, this is returned instead.
|
||||
|
||||
Note that this only applies to content related requests.
|
||||
"""
|
||||
data = bytes(request)
|
||||
if isinstance(request, TLRequest) and len(data) > 512:
|
||||
if content_related and len(data) > 512:
|
||||
gzipped = bytes(GzipPacked(data))
|
||||
return gzipped if len(gzipped) < len(data) else data
|
||||
else:
|
||||
|
|
|
@ -27,11 +27,6 @@ class MessageContainer(TLObject):
|
|||
],
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
return struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
|
||||
) + b''.join(bytes(m) for m in self.messages)
|
||||
|
||||
@classmethod
|
||||
def from_reader(cls, reader):
|
||||
# This assumes that .read_* calls are done in the order they appear
|
||||
|
@ -43,5 +38,5 @@ class MessageContainer(TLObject):
|
|||
before = reader.tell_position()
|
||||
obj = reader.tgread_object() # May over-read e.g. RpcResult
|
||||
reader.set_position(before + length)
|
||||
messages.append(TLMessage(msg_id, seq_no, obj, loop=None))
|
||||
messages.append(TLMessage(msg_id, seq_no, obj))
|
||||
return MessageContainer(messages)
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from .gzippacked import GzipPacked
|
||||
from .. import TLObject
|
||||
from ..functions import InvokeAfterMsgRequest
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
@ -17,83 +13,26 @@ class TLMessage(TLObject):
|
|||
message msg_id:long seqno:int bytes:int body:bytes = Message;
|
||||
|
||||
Each message has its own unique identifier, and the body is simply
|
||||
the serialized request that should be executed on the server. Then
|
||||
Telegram will, at some point, respond with the result for this msg.
|
||||
the serialized request that should be executed on the server, or
|
||||
the response object from Telegram. Since the body is always a valid
|
||||
object, it makes sense to store the object and not the bytes to
|
||||
ease working with them.
|
||||
|
||||
Thus it makes sense that requests and their result are bound to a
|
||||
sent `TLMessage`, and this result can be represented as a `Future`
|
||||
that will eventually be set with either a result, error or cancelled.
|
||||
There is no need to add serializing logic here since that can be
|
||||
inlined and is unlikely to change. Thus these are only needed to
|
||||
encapsulate responses.
|
||||
"""
|
||||
def __init__(self, msg_id, seq_no, obj, *, loop, out=False, after_id=0):
|
||||
SIZE_OVERHEAD = 12
|
||||
|
||||
def __init__(self, msg_id, seq_no, obj):
|
||||
self.msg_id = msg_id
|
||||
self.seq_no = seq_no
|
||||
self.obj = obj
|
||||
self.container_msg_id = None
|
||||
|
||||
# If no loop is given then it is an incoming message.
|
||||
# Only outgoing messages need the future to await them.
|
||||
self.future = loop.create_future() if loop else None
|
||||
|
||||
# After which message ID this one should run. We do this so
|
||||
# InvokeAfterMsgRequest is transparent to the user and we can
|
||||
# easily invoke after while confirming the original request.
|
||||
# TODO Currently we don't update this if another message ID changes
|
||||
self.after_id = after_id
|
||||
|
||||
# There are two use-cases for the TLMessage, outgoing and incoming.
|
||||
# Outgoing messages are meant to be serialized and sent across the
|
||||
# network so it makes sense to pack them as early as possible and
|
||||
# avoid this computation if it needs to be resent, and also shows
|
||||
# serializing-errors as early as possible (foreground task).
|
||||
#
|
||||
# We assume obj won't change so caching the bytes is safe to do.
|
||||
# Caching bytes lets us get the size in a fast way, necessary for
|
||||
# knowing whether a container can be sent (<1MB) or not (too big).
|
||||
#
|
||||
# Incoming messages don't really need this body, but we save the
|
||||
# msg_id and seq_no inside the body for consistency and raise if
|
||||
# one tries to bytes()-ify the entire message (len == 12).
|
||||
if not out:
|
||||
self._body = struct.pack('<qi', msg_id, seq_no)
|
||||
else:
|
||||
try:
|
||||
if self.after_id is None:
|
||||
body = GzipPacked.gzip_if_smaller(self.obj)
|
||||
else:
|
||||
body = GzipPacked.gzip_if_smaller(
|
||||
InvokeAfterMsgRequest(self.after_id, self.obj))
|
||||
except Exception:
|
||||
# struct.pack doesn't give a lot of information about
|
||||
# why it may fail so log the exception AND the object
|
||||
__log__.exception('Failed to pack %s', self.obj)
|
||||
raise
|
||||
|
||||
self._body = struct.pack('<qii', msg_id, seq_no, len(body)) + body
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'_': 'TLMessage',
|
||||
'msg_id': self.msg_id,
|
||||
'seq_no': self.seq_no,
|
||||
'obj': self.obj,
|
||||
'container_msg_id': self.container_msg_id
|
||||
'obj': self.obj
|
||||
}
|
||||
|
||||
@property
|
||||
def msg_id(self):
|
||||
return struct.unpack('<q', self._body[:8])[0]
|
||||
|
||||
@msg_id.setter
|
||||
def msg_id(self, value):
|
||||
self._body = struct.pack('<q', value) + self._body[8:]
|
||||
|
||||
@property
|
||||
def seq_no(self):
|
||||
return struct.unpack('<i', self._body[8:12])[0]
|
||||
|
||||
def __bytes__(self):
|
||||
if len(self._body) == 12: # msg_id, seqno
|
||||
raise TypeError('Incoming messages should not be bytes()-ed')
|
||||
|
||||
return self._body
|
||||
|
||||
def size(self):
|
||||
return len(self._body)
|
||||
|
|
|
@ -260,10 +260,7 @@ class Conversation(ChatGetter):
|
|||
if isinstance(event, type):
|
||||
event = event()
|
||||
|
||||
# Since we await resolve here we don't need to await resolved.
|
||||
# We know it has already been resolved, unlike when normally
|
||||
# adding an event handler, for which a task is created to resolve.
|
||||
await event.resolve()
|
||||
await event.resolve(self._client)
|
||||
|
||||
counter = Conversation._custom_counter
|
||||
Conversation._custom_counter += 1
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from . import Draft
|
||||
from .. import TLObject, types
|
||||
from .. import TLObject, types, functions
|
||||
from ... import utils
|
||||
|
||||
|
||||
|
@ -96,6 +96,17 @@ class Dialog:
|
|||
return await self._client.send_message(
|
||||
self.input_entity, *args, **kwargs)
|
||||
|
||||
async def delete(self):
|
||||
if self.is_channel:
|
||||
await self._client(functions.channels.LeaveChannelRequest(
|
||||
self.input_entity))
|
||||
else:
|
||||
if self.is_group:
|
||||
await self._client(functions.messages.DeleteChatUserRequest(
|
||||
self.entity.id, types.InputPeerSelf()))
|
||||
await self._client(functions.messages.DeleteHistoryRequest(
|
||||
self.input_entity, 0))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'_': 'Dialog',
|
||||
|
|
|
@ -3,9 +3,10 @@ import datetime
|
|||
from .. import TLObject
|
||||
from ..functions.messages import SaveDraftRequest
|
||||
from ..types import UpdateDraftMessage, DraftMessage
|
||||
from ... import default
|
||||
from ...errors import RPCError
|
||||
from ...extensions import markdown
|
||||
from ...utils import Default, get_peer_id, get_input_peer
|
||||
from ...utils import get_peer_id, get_input_peer
|
||||
|
||||
|
||||
class Draft:
|
||||
|
@ -116,7 +117,7 @@ class Draft:
|
|||
return not self._text
|
||||
|
||||
async def set_message(
|
||||
self, text=None, reply_to=0, parse_mode=Default,
|
||||
self, text=None, reply_to=0, parse_mode=default,
|
||||
link_preview=None):
|
||||
"""
|
||||
Changes the draft message on the Telegram servers. The changes are
|
||||
|
@ -163,7 +164,7 @@ class Draft:
|
|||
|
||||
return result
|
||||
|
||||
async def send(self, clear=True, parse_mode=Default):
|
||||
async def send(self, clear=True, parse_mode=default):
|
||||
"""
|
||||
Sends the contents of this draft to the dialog. This is just a
|
||||
wrapper around ``send_message(dialog.input_entity, *args, **kwargs)``.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import hashlib
|
||||
|
||||
from .. import functions, types
|
||||
from ... import utils
|
||||
from ... import default, utils
|
||||
|
||||
|
||||
class InlineBuilder:
|
||||
|
@ -55,7 +55,7 @@ class InlineBuilder:
|
|||
async def article(
|
||||
self, title, description=None,
|
||||
*, url=None, thumb=None, content=None,
|
||||
id=None, text=None, parse_mode=utils.Default, link_preview=True,
|
||||
id=None, text=None, parse_mode=default, link_preview=True,
|
||||
geo=None, period=60, contact=None, game=False, buttons=None
|
||||
):
|
||||
"""
|
||||
|
@ -105,7 +105,7 @@ class InlineBuilder:
|
|||
|
||||
async def photo(
|
||||
self, file, *, id=None,
|
||||
text=None, parse_mode=utils.Default, link_preview=True,
|
||||
text=None, parse_mode=default, link_preview=True,
|
||||
geo=None, period=60, contact=None, game=False, buttons=None
|
||||
):
|
||||
"""
|
||||
|
@ -144,7 +144,7 @@ class InlineBuilder:
|
|||
self, file, title=None, *, description=None, type=None,
|
||||
mime_type=None, attributes=None, force_document=False,
|
||||
voice_note=False, video_note=False, use_cache=True, id=None,
|
||||
text=None, parse_mode=utils.Default, link_preview=True,
|
||||
text=None, parse_mode=default, link_preview=True,
|
||||
geo=None, period=60, contact=None, game=False, buttons=None
|
||||
):
|
||||
"""
|
||||
|
@ -219,7 +219,7 @@ class InlineBuilder:
|
|||
|
||||
async def game(
|
||||
self, short_name, *, id=None,
|
||||
text=None, parse_mode=utils.Default, link_preview=True,
|
||||
text=None, parse_mode=default, link_preview=True,
|
||||
geo=None, period=60, contact=None, game=False, buttons=None
|
||||
):
|
||||
"""
|
||||
|
@ -247,7 +247,7 @@ class InlineBuilder:
|
|||
|
||||
async def _message(
|
||||
self, *,
|
||||
text=None, parse_mode=utils.Default, link_preview=True,
|
||||
text=None, parse_mode=default, link_preview=True,
|
||||
geo=None, period=60, contact=None, game=False, buttons=None
|
||||
):
|
||||
if sum(1 for x in (text, geo, contact, game) if x) != 1:
|
||||
|
|
|
@ -46,14 +46,6 @@ VALID_USERNAME_RE = re.compile(
|
|||
)
|
||||
|
||||
|
||||
class Default:
|
||||
"""
|
||||
Sentinel value to indicate that the default value should be used.
|
||||
Currently used for the ``parse_mode``, where a ``None`` mode should
|
||||
be considered different from using the default.
|
||||
"""
|
||||
|
||||
|
||||
def chunks(iterable, size=100):
|
||||
"""
|
||||
Turns the given iterable into chunks of the specified size,
|
||||
|
@ -271,9 +263,41 @@ def get_input_photo(photo):
|
|||
if isinstance(photo, types.PhotoEmpty):
|
||||
return types.InputPhotoEmpty()
|
||||
|
||||
if isinstance(photo, types.messages.ChatFull):
|
||||
photo = photo.full_chat
|
||||
if isinstance(photo, types.ChannelFull):
|
||||
return get_input_photo(photo.chat_photo)
|
||||
elif isinstance(photo, types.UserFull):
|
||||
return get_input_photo(photo.profile_photo)
|
||||
elif isinstance(photo, (types.Channel, types.Chat, types.User)):
|
||||
return get_input_photo(photo.photo)
|
||||
|
||||
if isinstance(photo, (types.UserEmpty, types.ChatEmpty,
|
||||
types.ChatForbidden, types.ChannelForbidden)):
|
||||
return types.InputPhotoEmpty()
|
||||
|
||||
_raise_cast_fail(photo, 'InputPhoto')
|
||||
|
||||
|
||||
def get_input_chat_photo(photo):
|
||||
"""Similar to :meth:`get_input_peer`, but for chat photos"""
|
||||
try:
|
||||
if photo.SUBCLASS_OF_ID == 0xd4eb2d74: # crc32(b'InputChatPhoto')
|
||||
return photo
|
||||
elif photo.SUBCLASS_OF_ID == 0xe7655f1f: # crc32(b'InputFile'):
|
||||
return types.InputChatUploadedPhoto(photo)
|
||||
except AttributeError:
|
||||
_raise_cast_fail(photo, 'InputChatPhoto')
|
||||
|
||||
photo = get_input_photo(photo)
|
||||
if isinstance(photo, types.InputPhoto):
|
||||
return types.InputChatPhoto(photo)
|
||||
elif isinstance(photo, types.InputPhotoEmpty):
|
||||
return types.InputChatPhotoEmpty()
|
||||
|
||||
_raise_cast_fail(photo, 'InputChatPhoto')
|
||||
|
||||
|
||||
def get_input_geo(geo):
|
||||
"""Similar to :meth:`get_input_peer`, but for geo points"""
|
||||
try:
|
||||
|
|
|
@ -25,7 +25,9 @@ AUTO_CASTS = {
|
|||
'InputNotifyPeer': 'await client._get_input_notify({})',
|
||||
'InputMedia': 'utils.get_input_media({})',
|
||||
'InputPhoto': 'utils.get_input_photo({})',
|
||||
'InputMessage': 'utils.get_input_message({})'
|
||||
'InputMessage': 'utils.get_input_message({})',
|
||||
'InputDocument': 'utils.get_input_document({})',
|
||||
'InputChatPhoto': 'utils.get_input_chat_photo({})',
|
||||
}
|
||||
|
||||
NAMED_AUTO_CASTS = {
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
from .crypto_test import CryptoTests
|
||||
from .network_test import NetworkTests
|
||||
from .parser_test import ParserTests
|
||||
from .tl_test import TLTests
|
||||
from .utils_test import UtilsTests
|
|
@ -1,143 +0,0 @@
|
|||
import unittest
|
||||
from hashlib import sha1
|
||||
|
||||
import telethon.helpers as utils
|
||||
from telethon.crypto import AES, Factorization
|
||||
# from crypto.PublicKey import RSA as PyCryptoRSA
|
||||
|
||||
|
||||
class CryptoTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Test known values
|
||||
self.key = b'\xd1\xf4MXy\x0c\xf8/z,\xe9\xf9\xa4\x17\x04\xd9C\xc9\xaba\x81\xf3\xf8\xdd\xcb\x0c6\x92\x01\x1f\xc2y'
|
||||
self.iv = b':\x02\x91x\x90Dj\xa6\x03\x90C\x08\x9e@X\xb5E\xffwy\xf3\x1c\xde\xde\xfbo\x8dm\xd6e.Z'
|
||||
|
||||
self.plain_text = b'Non encrypted text :D'
|
||||
self.plain_text_padded = b'My len is more uniform, promise!'
|
||||
|
||||
self.cipher_text = b'\xb6\xa7\xec.\xb9\x9bG\xcb\xe9{\x91[\x12\xfc\x84D\x1c' \
|
||||
b'\x93\xd9\x17\x03\xcd\xd6\xb1D?\x98\xd2\xb5\xa5U\xfd'
|
||||
|
||||
self.cipher_text_padded = b"W\xd1\xed'\x01\xa6c\xc3\xcb\xef\xaa\xe5\x1d\x1a" \
|
||||
b"[\x1b\xdf\xcdI\x1f>Z\n\t\xb9\xd2=\xbaF\xd1\x8e'"
|
||||
|
||||
def test_sha1(self):
|
||||
string = 'Example string'
|
||||
|
||||
hash_sum = sha1(string.encode('utf-8')).digest()
|
||||
expected = b'\nT\x92|\x8d\x06:)\x99\x04\x8e\xf8j?\xc4\x8e\xd3}m9'
|
||||
|
||||
self.assertEqual(hash_sum, expected,
|
||||
msg='Invalid sha1 hash_sum representation (should be {}, but is {})'
|
||||
.format(expected, hash_sum))
|
||||
|
||||
@unittest.skip("test_aes_encrypt needs fix")
|
||||
def test_aes_encrypt(self):
|
||||
value = AES.encrypt_ige(self.plain_text, self.key, self.iv)
|
||||
take = 16 # Don't take all the bytes, since latest involve are random padding
|
||||
self.assertEqual(value[:take], self.cipher_text[:take],
|
||||
msg='Ciphered text ("{}") does not equal expected ("{}")'
|
||||
.format(value[:take], self.cipher_text[:take]))
|
||||
|
||||
value = AES.encrypt_ige(self.plain_text_padded, self.key, self.iv)
|
||||
self.assertEqual(value, self.cipher_text_padded,
|
||||
msg='Ciphered text ("{}") does not equal expected ("{}")'
|
||||
.format(value, self.cipher_text_padded))
|
||||
|
||||
def test_aes_decrypt(self):
|
||||
# The ciphered text must always be padded
|
||||
value = AES.decrypt_ige(self.cipher_text_padded, self.key, self.iv)
|
||||
self.assertEqual(value, self.plain_text_padded,
|
||||
msg='Decrypted text ("{}") does not equal expected ("{}")'
|
||||
.format(value, self.plain_text_padded))
|
||||
|
||||
@unittest.skip("test_calc_key needs fix")
|
||||
def test_calc_key(self):
|
||||
# TODO Upgrade test for MtProto 2.0
|
||||
shared_key = b'\xbc\xd2m\xb7\xcav\xf4][\x88\x83\' \xf3\x11\x8as\xd04\x941\xae' \
|
||||
b'*O\x03\x86\x9a/H#\x1a\x8c\xb5j\xe9$\xe0IvCm^\xe70\x1a5C\t\x16' \
|
||||
b'\x03\xd2\x9d\xa9\x89\xd6\xce\x08P\x0fdr\xa0\xb3\xeb\xfecv\x1a' \
|
||||
b'\xdfJ\x14\x96\x98\x16\xa3G\xab\x04\x14!\\\xeb\n\xbcn\xdf\xc4%' \
|
||||
b'\xc6\t\xb7\x16\x14\x9c\'\x81\x15=\xb0\xaf\x0e\x0bR\xaa\x0466s' \
|
||||
b'\xf0\xcf\xb7\xb8>,D\x94x\xd7\xf8\xe0\x84\xcb%\xd3\x05\xb2\xe8' \
|
||||
b'\x95Mr?\xa2\xe8In\xf9\x0b[E\x9b\xaa\x0cX\x7f\x0ei\xde\xeed\x1d' \
|
||||
b'x/J\xce\xea^}0;\xa83B\xbbR\xa1\xbfe\x04\xb9\x1e\xa1"f=\xa5M@' \
|
||||
b'\x9e\xdd\x81\x80\xc9\xa5\xfb\xfcg\xdd\x15\x03p!\x0ffD\x16\x892' \
|
||||
b'\xea\xca\xb1A\x99O\xa94P\xa9\xa2\xc6;\xb2C9\x1dC5\xd2\r\xecL' \
|
||||
b'\xd9\xabw-\x03\ry\xc2v\x17]\x02\x15\x0cBa\x97\xce\xa5\xb1\xe4]' \
|
||||
b'\x8e\xe0,\xcfC{o\xfa\x99f\xa4pM\x00'
|
||||
|
||||
# Calculate key being the client
|
||||
msg_key = b'\xba\x1a\xcf\xda\xa8^Cbl\xfa\xb6\x0c:\x9b\xb0\xfc'
|
||||
|
||||
key, iv = utils.calc_key(shared_key, msg_key, client=True)
|
||||
expected_key = b"\xaf\xe3\x84Qm\xe0!\x0c\xd91\xe4\x9a\xa0v_gc" \
|
||||
b"x\xa1\xb0\xc9\xbc\x16'v\xcf,\x9dM\xae\xc6\xa5"
|
||||
|
||||
expected_iv = b'\xb8Q\xf3\xc5\xa3]\xc6\xdf\x9e\xe0Q\xbd"\x8d' \
|
||||
b'\x13\t\x0e\x9a\x9d^8\xa2\xf8\xe7\x00w\xd9\xc1' \
|
||||
b'\xa7\xa0\xf7\x0f'
|
||||
|
||||
self.assertEqual(key, expected_key,
|
||||
msg='Invalid key (expected ("{}"), got ("{}"))'
|
||||
.format(expected_key, key))
|
||||
self.assertEqual(iv, expected_iv,
|
||||
msg='Invalid IV (expected ("{}"), got ("{}"))'
|
||||
.format(expected_iv, iv))
|
||||
|
||||
# Calculate key being the server
|
||||
msg_key = b'\x86m\x92i\xcf\x8b\x93\xaa\x86K\x1fi\xd04\x83]'
|
||||
|
||||
key, iv = utils.calc_key(shared_key, msg_key, client=False)
|
||||
expected_key = b'\xdd0X\xb6\x93\x8e\xc9y\xef\x83\xf8\x8cj' \
|
||||
b'\xa7h\x03\xe2\xc6\xb16\xc5\xbb\xfc\xe7' \
|
||||
b'\xdf\xd6\xb1g\xf7u\xcfk'
|
||||
|
||||
expected_iv = b'\xdcL\xc2\x18\x01J"X\x86lb\xb6\xb547\xfd' \
|
||||
b'\xe2a4\xb6\xaf}FS\xd7[\xe0N\r\x19\xfb\xbc'
|
||||
|
||||
self.assertEqual(key, expected_key,
|
||||
msg='Invalid key (expected ("{}"), got ("{}"))'
|
||||
.format(expected_key, key))
|
||||
self.assertEqual(iv, expected_iv,
|
||||
msg='Invalid IV (expected ("{}"), got ("{}"))'
|
||||
.format(expected_iv, iv))
|
||||
|
||||
def test_generate_key_data_from_nonce(self):
|
||||
server_nonce = int.from_bytes(b'The 16-bit nonce', byteorder='little')
|
||||
new_nonce = int.from_bytes(b'The new, calculated 32-bit nonce', byteorder='little')
|
||||
|
||||
key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
expected_key = b'/\xaa\x7f\xa1\xfcs\xef\xa0\x99zh\x03M\xa4\x8e\xb4\xab\x0eE]b\x95|\xfe\xc0\xf8\x1f\xd4\xa0\xd4\xec\x91'
|
||||
expected_iv = b'\xf7\xae\xe3\xc8+=\xc2\xb8\xd1\xe1\x1b\x0e\x10\x07\x9fn\x9e\xdc\x960\x05\xf9\xea\xee\x8b\xa1h The '
|
||||
|
||||
self.assertEqual(key, expected_key,
|
||||
msg='Key ("{}") does not equal expected ("{}")'
|
||||
.format(key, expected_key))
|
||||
self.assertEqual(iv, expected_iv,
|
||||
msg='IV ("{}") does not equal expected ("{}")'
|
||||
.format(iv, expected_iv))
|
||||
|
||||
# test_fringerprint_from_key can't be skipped due to ImportError
|
||||
# def test_fingerprint_from_key(self):
|
||||
# assert rsa._compute_fingerprint(PyCryptoRSA.importKey(
|
||||
# '-----BEGIN RSA PUBLIC KEY-----\n'
|
||||
# 'MIIBCgKCAQEAwVACPi9w23mF3tBkdZz+zwrzKOaaQdr01vAbU4E1pvkfj4sqDsm6\n'
|
||||
# 'lyDONS789sVoD/xCS9Y0hkkC3gtL1tSfTlgCMOOul9lcixlEKzwKENj1Yz/s7daS\n'
|
||||
# 'an9tqw3bfUV/nqgbhGX81v/+7RFAEd+RwFnK7a+XYl9sluzHRyVVaTTveB2GazTw\n'
|
||||
# 'Efzk2DWgkBluml8OREmvfraX3bkHZJTKX4EQSjBbbdJ2ZXIsRrYOXfaA+xayEGB+\n'
|
||||
# '8hdlLmAjbCVfaigxX0CDqWeR1yFL9kwd9P0NsZRPsmoqVwMbMu7mStFai6aIhc3n\n'
|
||||
# 'Slv8kg9qv1m6XHVQY3PnEw+QQtqSIXklHwIDAQAB\n'
|
||||
# '-----END RSA PUBLIC KEY-----'
|
||||
# )) == b'!k\xe8l\x02+\xb4\xc3', 'Wrong fingerprint calculated'
|
||||
|
||||
def test_factorize(self):
|
||||
pq = 3118979781119966969
|
||||
p, q = Factorization.factorize(pq)
|
||||
if p > q:
|
||||
p, q = q, p
|
||||
|
||||
self.assertEqual(p, 1719614201,
|
||||
msg='Factorized pair did not yield the correct result')
|
||||
self.assertEqual(q, 1813767169,
|
||||
msg='Factorized pair did not yield the correct result')
|
|
@ -1,49 +0,0 @@
|
|||
import unittest
|
||||
import os
|
||||
from io import BytesIO
|
||||
from random import randint
|
||||
from hashlib import sha256
|
||||
from telethon import TelegramClient
|
||||
|
||||
# Fill in your api_id and api_hash when running the tests
|
||||
# and REMOVE THEM once you've finished testing them.
|
||||
api_id = None
|
||||
api_hash = None
|
||||
|
||||
|
||||
class HigherLevelTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if not api_id or not api_hash:
|
||||
raise ValueError('Please fill in both your api_id and api_hash.')
|
||||
|
||||
@unittest.skip("you can't seriously trash random mobile numbers like that :)")
|
||||
def test_cdn_download(self):
|
||||
client = TelegramClient(None, api_id, api_hash)
|
||||
client.session.set_dc(0, '149.154.167.40', 80)
|
||||
self.assertTrue(client.connect())
|
||||
|
||||
try:
|
||||
phone = '+999662' + str(randint(0, 9999)).zfill(4)
|
||||
client.send_code_request(phone)
|
||||
client.sign_up('22222', 'Test', 'DC')
|
||||
|
||||
me = client.get_me()
|
||||
data = os.urandom(2 ** 17)
|
||||
client.send_file(
|
||||
me, data,
|
||||
progress_callback=lambda c, t:
|
||||
print('test_cdn_download:uploading {:.2%}...'.format(c/t))
|
||||
)
|
||||
msg = client.get_messages(me)[1][0]
|
||||
|
||||
out = BytesIO()
|
||||
client.download_media(msg, out)
|
||||
self.assertEqual(sha256(data).digest(), sha256(out.getvalue()).digest())
|
||||
|
||||
out = BytesIO()
|
||||
client.download_media(msg, out) # Won't redirect
|
||||
self.assertEqual(sha256(data).digest(), sha256(out.getvalue()).digest())
|
||||
|
||||
client.log_out()
|
||||
finally:
|
||||
client.disconnect()
|
|
@ -1,44 +0,0 @@
|
|||
import random
|
||||
import socket
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
import telethon.network.authenticator as authenticator
|
||||
from telethon.extensions import TcpClient
|
||||
from telethon.network import Connection
|
||||
|
||||
|
||||
def run_server_echo_thread(port):
|
||||
def server_thread():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', port))
|
||||
s.listen(1)
|
||||
connection, address = s.accept()
|
||||
with connection:
|
||||
data = connection.recv(16)
|
||||
connection.send(data)
|
||||
|
||||
server = threading.Thread(target=server_thread)
|
||||
server.start()
|
||||
|
||||
|
||||
class NetworkTests(unittest.TestCase):
|
||||
|
||||
@unittest.skip("test_tcp_client needs fix")
|
||||
def test_tcp_client(self):
|
||||
port = random.randint(50000, 60000) # Arbitrary non-privileged port
|
||||
run_server_echo_thread(port)
|
||||
|
||||
msg = b'Unit testing...'
|
||||
client = TcpClient()
|
||||
client.connect('localhost', port)
|
||||
client.write(msg)
|
||||
self.assertEqual(msg, client.read(15),
|
||||
msg='Read message does not equal sent message')
|
||||
client.close()
|
||||
|
||||
@unittest.skip("Some parameters changed, so IP doesn't go there anymore.")
|
||||
def test_authenticator(self):
|
||||
transport = Connection('149.154.167.91', 443)
|
||||
self.assertTrue(authenticator.do_authentication(transport))
|
||||
transport.close()
|
|
@ -1,8 +0,0 @@
|
|||
import unittest
|
||||
|
||||
|
||||
class ParserTests(unittest.TestCase):
|
||||
"""There are no tests yet"""
|
||||
@unittest.skip("there should be parser tests")
|
||||
def test_parser(self):
|
||||
self.assertTrue(True)
|
|
@ -1,8 +0,0 @@
|
|||
import unittest
|
||||
|
||||
|
||||
class TLTests(unittest.TestCase):
|
||||
"""There are no tests yet"""
|
||||
@unittest.skip("there should be TL tests")
|
||||
def test_tl(self):
|
||||
self.assertTrue(True)
|
|
@ -1,66 +0,0 @@
|
|||
import os
|
||||
import unittest
|
||||
from telethon.tl import TLObject
|
||||
from telethon.extensions import BinaryReader
|
||||
|
||||
|
||||
class UtilsTests(unittest.TestCase):
|
||||
def test_binary_writer_reader(self):
|
||||
# Test that we can read properly
|
||||
data = b'\x01\x05\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
|
||||
b'\x88A\x00\x00\x00\x00\x00\x009@\x1a\x1b\x1c\x1d\x1e\x1f ' \
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
|
||||
b'\x00\x80'
|
||||
|
||||
with BinaryReader(data) as reader:
|
||||
value = reader.read_byte()
|
||||
self.assertEqual(value, 1,
|
||||
msg='Example byte should be 1 but is {}'.format(value))
|
||||
|
||||
value = reader.read_int()
|
||||
self.assertEqual(value, 5,
|
||||
msg='Example integer should be 5 but is {}'.format(value))
|
||||
|
||||
value = reader.read_long()
|
||||
self.assertEqual(value, 13,
|
||||
msg='Example long integer should be 13 but is {}'.format(value))
|
||||
|
||||
value = reader.read_float()
|
||||
self.assertEqual(value, 17.0,
|
||||
msg='Example float should be 17.0 but is {}'.format(value))
|
||||
|
||||
value = reader.read_double()
|
||||
self.assertEqual(value, 25.0,
|
||||
msg='Example double should be 25.0 but is {}'.format(value))
|
||||
|
||||
value = reader.read(7)
|
||||
self.assertEqual(value, bytes([26, 27, 28, 29, 30, 31, 32]),
|
||||
msg='Example bytes should be {} but is {}'
|
||||
.format(bytes([26, 27, 28, 29, 30, 31, 32]), value))
|
||||
|
||||
value = reader.read_large_int(128, signed=False)
|
||||
self.assertEqual(value, 2**127,
|
||||
msg='Example large integer should be {} but is {}'.format(2**127, value))
|
||||
|
||||
def test_binary_tgwriter_tgreader(self):
|
||||
small_data = os.urandom(33)
|
||||
small_data_padded = os.urandom(19) # +1 byte for length = 20 (%4 = 0)
|
||||
|
||||
large_data = os.urandom(999)
|
||||
large_data_padded = os.urandom(1024)
|
||||
|
||||
data = (small_data, small_data_padded, large_data, large_data_padded)
|
||||
string = 'Testing Telegram strings, this should work properly!'
|
||||
serialized = b''.join(TLObject.serialize_bytes(d) for d in data) + \
|
||||
TLObject.serialize_bytes(string)
|
||||
|
||||
with BinaryReader(serialized) as reader:
|
||||
# And then try reading it without errors (it should be unharmed!)
|
||||
for datum in data:
|
||||
value = reader.tgread_bytes()
|
||||
self.assertEqual(value, datum,
|
||||
msg='Example bytes should be {} but is {}'.format(datum, value))
|
||||
|
||||
value = reader.tgread_string()
|
||||
self.assertEqual(value, string,
|
||||
msg='Example string should be {} but is {}'.format(string, value))
|
Loading…
Reference in New Issue
Block a user