Merge branch 'master' into master

This commit is contained in:
Lonami 2018-10-05 20:07:16 +02:00 committed by GitHub
commit 1c6f1ac148
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1049 additions and 1237 deletions

View File

@ -1,24 +0,0 @@
#!/usr/bin/env python3
import unittest
if __name__ == '__main__':
from telethon_tests import \
CryptoTests, ParserTests, TLTests, UtilsTests, NetworkTests
test_classes = [CryptoTests, ParserTests, TLTests, UtilsTests]
network = input('Run network tests (y/n)?: ').lower() == 'y'
if network:
test_classes.append(NetworkTests)
loader = unittest.TestLoader()
suites_list = []
for test_class in test_classes:
suite = loader.loadTestsFromTestCase(test_class)
suites_list.append(suite)
big_suite = unittest.TestSuite(suites_list)
runner = unittest.TextTestRunner()
results = runner.run(big_suite)

View File

@ -2,7 +2,7 @@ import logging
from .client.telegramclient import TelegramClient
from .network import connection
from .tl import types, functions, custom
from . import version, events, utils, errors
from . import version, events, utils, errors, full_sync
__version__ = version.__version__

View File

@ -2,8 +2,8 @@ import itertools
import re
from .users import UserMethods
from .. import utils
from ..tl import types, custom
from .. import default, utils
from ..tl import types
class MessageParseMethods(UserMethods):
@ -62,7 +62,7 @@ class MessageParseMethods(UserMethods):
"""
Returns a (parsed message, entities) tuple depending on ``parse_mode``.
"""
if parse_mode == utils.Default:
if parse_mode == default:
parse_mode = self._parse_mode
else:
parse_mode = utils.sanitize_parse_mode(parse_mode)

View File

@ -8,8 +8,8 @@ from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods
from .uploads import UploadMethods
from .buttons import ButtonMethods
from .. import utils, helpers
from ..tl import types, functions, custom
from .. import default, helpers, utils
from ..tl import types, functions
__log__ = logging.getLogger(__name__)
@ -360,7 +360,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
async def send_message(
self, entity, message='', *, reply_to=None,
parse_mode=utils.Default, link_preview=True, file=None,
parse_mode=default, link_preview=True, file=None,
force_document=False, clear_draft=False, buttons=None,
silent=None):
"""
@ -584,7 +584,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
async def edit_message(
self, entity, message=None, text=None,
*, parse_mode=utils.Default, link_preview=True, file=None,
*, parse_mode=default, link_preview=True, file=None,
buttons=None):
"""
Edits the given message ID (to change its contents or disable preview).

View File

@ -1,6 +1,5 @@
import abc
import asyncio
import collections
import inspect
import logging
import platform
@ -12,7 +11,6 @@ from .. import version
from ..crypto import rsa
from ..extensions import markdown
from ..network import MTProtoSender, ConnectionTcpFull
from ..network.mtprotostate import MTProtoState
from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import TLObject, functions, types
from ..tl.alltlobjects import LAYER
@ -55,7 +53,7 @@ class TelegramBaseClient(abc.ABC):
connection (`telethon.network.connection.common.Connection`, optional):
The connection instance to be used when creating a new connection
to the servers. If it's a type, the `proxy` argument will be used.
to the servers. It **must** be a type.
Defaults to `telethon.network.connection.tcpfull.ConnectionTcpFull`.
@ -68,11 +66,11 @@ class TelegramBaseClient(abc.ABC):
A tuple consisting of ``(socks.SOCKS5, 'host', port)``.
See https://github.com/Anorov/PySocks#usage-1 for more.
timeout (`int` | `float` | `timedelta`, optional):
The timeout to be used when connecting, sending and receiving
responses from the network. This is **not** the timeout to
be used when ``await``'ing for invoked requests, and you
should use ``asyncio.wait`` or ``asyncio.wait_for`` for that.
timeout (`int` | `float`, optional):
The timeout in seconds to be used when connecting.
This is **not** the timeout to be used when ``await``'ing for
invoked requests, and you should use ``asyncio.wait`` or
``asyncio.wait_for`` for that.
request_retries (`int`, optional):
How many times a request should be retried. Request are retried
@ -150,7 +148,7 @@ class TelegramBaseClient(abc.ABC):
connection=ConnectionTcpFull,
use_ipv6=False,
proxy=None,
timeout=timedelta(seconds=10),
timeout=10,
request_retries=5,
connection_retries=5,
auto_reconnect=True,
@ -205,11 +203,12 @@ class TelegramBaseClient(abc.ABC):
self._request_retries = request_retries or sys.maxsize
self._connection_retries = connection_retries or sys.maxsize
self._proxy = proxy
self._timeout = timeout
self._auto_reconnect = auto_reconnect
if isinstance(connection, type):
connection = connection(
proxy=proxy, timeout=timeout, loop=self._loop)
assert isinstance(connection, type)
self._connection = connection
# Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest.
@ -227,12 +226,12 @@ class TelegramBaseClient(abc.ABC):
)
)
state = MTProtoState(self.session.auth_key)
self._connection = connection
self._sender = MTProtoSender(
state, connection, self._loop,
self._loop,
retries=self._connection_retries,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
update_callback=self._handle_update,
auth_key_callback=self._auth_key_callback,
auto_reconnect_callback=self._handle_auto_reconnect
@ -250,10 +249,6 @@ class TelegramBaseClient(abc.ABC):
# Save whether the user is authorized here (a.k.a. logged in)
self._authorized = None # None = We don't know yet
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
self._updates_handle = None
self._last_request = time.time()
self._channel_pts = {}
@ -309,8 +304,10 @@ class TelegramBaseClient(abc.ABC):
"""
Connects to Telegram.
"""
await self._sender.connect(
self.session.server_address, self.session.port)
await self._sender.connect(self.session.auth_key, self._connection(
self.session.server_address, self.session.port,
loop=self._loop, proxy=self._proxy
))
await self._sender.send(self._init_with(
functions.help.GetConfigRequest()))
@ -373,7 +370,7 @@ class TelegramBaseClient(abc.ABC):
await self.session.set_dc(dc.id, dc.ip_address, dc.port)
# auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it.
self.session.auth_key = self._sender.state.auth_key = None
self.session.auth_key = None
await self.session.save()
await self._disconnect()
return await self.connect()
@ -416,13 +413,13 @@ class TelegramBaseClient(abc.ABC):
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization
dc = await self._get_dc(dc_id)
state = MTProtoState(None)
# Can't reuse self._sender._connection as it has its own seqno.
#
# If one were to do that, Telegram would reset the connection
# with no further clues.
sender = MTProtoSender(state, self._connection.clone(), self._loop)
await sender.connect(dc.ip_address, dc.port)
sender = MTProtoSender(self._loop)
await sender.connect(None, self._connection(
dc.ip_address, dc.port, loop=self._loop, proxy=self._proxy))
__log__.info('Exporting authorization for data center %s', dc)
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
req = self._init_with(functions.auth.ImportAuthorizationRequest(

View File

@ -5,11 +5,11 @@ import os
import pathlib
import re
from io import BytesIO
from mimetypes import guess_type
from .buttons import ButtonMethods
from .messageparse import MessageParseMethods
from .users import UserMethods
from .buttons import ButtonMethods
from .. import default
from .. import utils, helpers
from ..tl import types, functions, custom
@ -23,7 +23,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
async def send_file(
self, entity, file, *, caption='', force_document=False,
progress_callback=None, reply_to=None, attributes=None,
thumb=None, allow_cache=True, parse_mode=utils.Default,
thumb=None, allow_cache=True, parse_mode=default,
voice_note=False, video_note=False, buttons=None, silent=None,
**kwargs):
"""
@ -180,7 +180,7 @@ class UploadMethods(ButtonMethods, MessageParseMethods, UserMethods):
async def _send_album(self, entity, files, caption='',
progress_callback=None, reply_to=None,
parse_mode=utils.Default, silent=None):
parse_mode=default, silent=None):
"""Specialized version of .send_file for albums"""
# We don't care if the user wants to avoid cache, we will use it
# anyway. Why? The cached version will be exactly the same thing

5
telethon/default.py Normal file
View File

@ -0,0 +1,5 @@
"""
Sentinel module to signify that a parameter should use its default value.
Useful when the default value or ``None`` are both valid options.
"""

View File

@ -124,7 +124,7 @@ class InlineQuery(EventBuilder):
async def answer(
self, results=None, cache_time=0, *,
gallery=False, private=False,
gallery=False, next_offset=None, private=False,
switch_pm=None, switch_pm_param=''):
"""
Answers the inline query with the given results.
@ -147,6 +147,10 @@ class InlineQuery(EventBuilder):
gallery (`bool`, optional):
Whether the results should show as a gallery (grid) or not.
next_offset (`str`, optional):
The offset the client will send when the user scrolls the
results and it repeats the request.
private (`bool`, optional):
Whether the results should be cached by Telegram
@ -163,11 +167,14 @@ class InlineQuery(EventBuilder):
if self._answered:
return
results = [self._as_awaitable(x, self._client.loop)
for x in results]
if results:
results = [self._as_awaitable(x, self._client.loop)
for x in results]
done, _ = await asyncio.wait(results, loop=self._client.loop)
results = [x.result() for x in done]
done, _ = await asyncio.wait(results, loop=self._client.loop)
results = [x.result() for x in done]
else:
results = []
if switch_pm:
switch_pm = types.InlineBotSwitchPM(switch_pm, switch_pm_param)
@ -178,6 +185,7 @@ class InlineQuery(EventBuilder):
results=results,
cache_time=cache_time,
gallery=gallery,
next_offset=next_offset,
private=private,
switch_pm=switch_pm
)

View File

@ -4,4 +4,3 @@ communication with support for cancelling the operation, and an utility class
to read arbitrary binary data in a more comfortable way, with int/strings/etc.
"""
from .binaryreader import BinaryReader
from .tcpclient import TcpClient

View File

@ -9,8 +9,8 @@ from html.parser import HTMLParser
from ..tl.types import (
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
MessageEntityPre, MessageEntityEmail, MessageEntityUrl,
MessageEntityTextUrl
)
MessageEntityTextUrl, MessageEntityMentionName
)
# Helpers from markdown.py
@ -178,6 +178,9 @@ def unparse(text, entities):
elif entity_type == MessageEntityTextUrl:
html.append('<a href="{}">{}</a>'
.format(escape(entity.url), entity_text))
elif entity_type == MessageEntityMentionName:
html.append('<a href="tg://user?id={}">{}</a>'
.format(entity.user_id, entity_text))
else:
skip_entity = True
last_offset = entity.offset + (0 if skip_entity else entity.length)

View File

@ -9,8 +9,8 @@ from ..helpers import add_surrogate, del_surrogate
from ..tl import TLObject
from ..tl.types import (
MessageEntityBold, MessageEntityItalic, MessageEntityCode,
MessageEntityPre, MessageEntityTextUrl
)
MessageEntityPre, MessageEntityTextUrl, MessageEntityMentionName
)
DEFAULT_DELIMITERS = {
'**': MessageEntityBold,
@ -161,11 +161,17 @@ def unparse(text, entities, delimiters=None, url_fmt=None):
delimiter = delimiters.get(type(entity), None)
if delimiter:
text = text[:s] + delimiter + text[s:e] + delimiter + text[e:]
elif isinstance(entity, MessageEntityTextUrl) and url_fmt:
text = (
text[:s] +
add_surrogate(url_fmt.format(text[s:e], entity.url)) +
text[e:]
)
elif url_fmt:
url = None
if isinstance(entity, MessageEntityTextUrl):
url = entity.url
elif isinstance(entity, MessageEntityMentionName):
url = 'tg://user?id={}'.format(entity.user_id)
if url:
text = (
text[:s] +
add_surrogate(url_fmt.format(text[s:e], url)) +
text[e:]
)
return del_surrogate(text)

View File

@ -1,171 +0,0 @@
"""
This module holds a rough implementation of the C# TCP client.
This class is **not** safe across several tasks since partial reads
may be ``await``'ed before being able to return the exact byte count.
This class is also not concerned about disconnections or retries of
any sort, nor any other kind of errors such as connecting twice.
"""
import asyncio
import errno
import logging
import socket
import ssl
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
}
# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
try:
import socks
except ImportError:
socks = None
SSL_PORT = 443
__log__ = logging.getLogger(__name__)
class TcpClient:
"""A simple TCP client to ease the work with sockets and proxies."""
class SocketClosed(ConnectionError):
pass
def __init__(self, *, loop, timeout, ssl=None, proxy=None):
"""
Initializes the TCP client.
:param proxy: the proxy to be used, if any.
:param timeout: the timeout for connect, read and write operations.
:param ssl: ssl.wrap_socket keyword arguments to use when connecting
if port == SSL_PORT, or do nothing if not present.
"""
self._loop = loop
self.proxy = proxy
self.ssl = ssl
self._socket = None
self._reader = None
self._writer = None
self._closed = asyncio.Event(loop=self._loop)
self._closed.set()
if isinstance(timeout, (int, float)):
self.timeout = float(timeout)
elif hasattr(timeout, 'seconds'):
self.timeout = float(timeout.seconds)
else:
raise TypeError('Invalid timeout type: {}'.format(type(timeout)))
@staticmethod
def _create_socket(mode, proxy):
if proxy is None:
s = socket.socket(mode, socket.SOCK_STREAM)
else:
__log__.info('Connection will be made through proxy %s', proxy)
import socks
s = socks.socksocket(mode, socket.SOCK_STREAM)
if isinstance(proxy, dict):
s.set_proxy(**proxy)
else: # tuple, list, etc.
s.set_proxy(*proxy)
s.setblocking(False)
return s
async def connect(self, ip, port):
"""
Tries connecting to IP:port unless an OSError is raised.
:param ip: the IP to connect to.
:param port: the port to connect to.
"""
if ':' in ip: # IPv6
ip = ip.replace('[', '').replace(']', '')
mode, address = socket.AF_INET6, (ip, port, 0, 0)
else:
mode, address = socket.AF_INET, (ip, port)
try:
if self._socket is None:
self._socket = self._create_socket(mode, self.proxy)
wrap_ssl = self.ssl and port == SSL_PORT
else:
wrap_ssl = False
await asyncio.wait_for(
self._loop.sock_connect(self._socket, address),
timeout=self.timeout,
loop=self._loop
)
if wrap_ssl:
# Temporarily set the socket to blocking
# (timeout) until connection is established.
self._socket.settimeout(self.timeout)
self._socket = ssl.wrap_socket(
self._socket, do_handshake_on_connect=True, **self.ssl)
self._socket.setblocking(False)
self._closed.clear()
self._reader, self._writer =\
await asyncio.open_connection(sock=self._socket)
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
raise ConnectionResetError() from e
else:
raise
@property
def is_connected(self):
"""Determines whether the client is connected or not."""
return not self._closed.is_set()
def close(self):
"""Closes the connection."""
fd = None
try:
if self._writer is not None:
self._writer.close()
if self._socket is not None:
fd = self._socket.fileno()
if self.is_connected:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally:
self._socket = None
self._reader = None
self._writer = None
self._closed.set()
if fd and fd != -1:
self._loop.remove_reader(fd)
async def write(self, data):
"""
Writes (sends) the specified bytes to the connected peer.
:param data: the data to send.
"""
if not self.is_connected:
raise ConnectionResetError('Not connected')
self._writer.write(data)
await self._writer.drain()
async def read(self, size):
"""
Reads (receives) a whole block of size bytes from the connected peer.
:param size: the size of the block to be read.
:return: the read data with len(data) == size.
"""
if not self.is_connected:
raise ConnectionResetError('Not connected')
return await self._reader.readexactly(size)

161
telethon/full_sync.py Normal file
View File

@ -0,0 +1,161 @@
"""
This magical module will rewrite all public methods in the public interface of
the library so they can delegate the call to an asyncio event loop in another
thread and wait for the result. This rewrite may not be desirable if the end
user always uses the methods they way they should be ran, but it's incredibly
useful for quick scripts and legacy code.
"""
import asyncio
import functools
import inspect
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from async_generator import isasyncgenfunction
from .client.telegramclient import TelegramClient
from .tl.custom import (
Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
)
from .tl.custom.chatgetter import ChatGetter
from .tl.custom.sendergetter import SenderGetter
async def _proxy_future(af, cf):
try:
res = await af
cf.set_result(res)
except Exception as e:
cf.set_exception(e)
def _sync_result(loop, x):
f = Future()
loop.call_soon_threadsafe(asyncio.ensure_future, _proxy_future(x, f))
return f.result()
class _SyncGen:
def __init__(self, loop, gen):
self.loop = loop
self.gen = gen
def __iter__(self):
return self
def __next__(self):
try:
return _sync_result(self.loop, self.gen.__anext__())
except StopAsyncIteration:
raise StopIteration from None
def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result):
method = getattr(t, method_name)
@functools.wraps(method)
def syncified(*args, **kwargs):
coro = method(*args, **kwargs)
return (
coro if threading.get_ident() == thread_ident
else syncifier(loop, coro)
)
setattr(t, method_name, syncified)
def _syncify(*types, loop, thread_ident):
for t in types:
for method_name in dir(t):
if not method_name.startswith('_') or method_name == '__call__':
if inspect.iscoroutinefunction(getattr(t, method_name)):
_syncify_wrap(t, method_name, loop, thread_ident, _sync_result)
elif isasyncgenfunction(getattr(t, method_name)):
_syncify_wrap(t, method_name, loop, thread_ident, _SyncGen)
__asyncthread = None
def enable(*, loop=None, executor=None, max_workers=1):
"""
Enables the fully synchronous mode. You should enable this at
the beginning of your script, right after importing, only once.
**Please** make sure to call `stop` at the end of your script.
You can define the event loop to use and executor, otherwise
the default loop and ``ThreadPoolExecutor`` will be used, in
which case `max_workers` will be passed to it. If you pass a
custom executor, `max_workers` will be ignored.
"""
global __asyncthread
if __asyncthread is not None:
raise RuntimeError("full_sync can only be enabled once")
if not loop:
loop = asyncio.get_event_loop()
if not executor:
executor = ThreadPoolExecutor(max_workers=max_workers)
def start():
asyncio.set_event_loop(loop)
loop.run_forever()
__asyncthread = threading.Thread(
target=start, name="__telethon_async_thread__", daemon=True
)
__asyncthread.start()
__asyncthread.loop = loop
__asyncthread.executor = executor
TelegramClient.__init__ = functools.partialmethod(
TelegramClient.__init__, loop=loop
)
_syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter,
SenderGetter, Forward, Message, InlineResult, Conversation,
loop=loop, thread_ident=__asyncthread.ident)
_syncify_wrap(TelegramClient, "start", loop, __asyncthread.ident)
old_add_event_handler = TelegramClient.add_event_handler
old_remove_event_handler = TelegramClient.remove_event_handler
proxied_event_handlers = {}
@functools.wraps(old_add_event_handler)
def add_proxied_event_handler(self, callback, *args, **kwargs):
async def _proxy(*pargs, **pkwargs):
await loop.run_in_executor(
executor, functools.partial(callback, *pargs, **pkwargs))
proxied_event_handlers[callback] = _proxy
args = (self, _proxy, *args)
return old_add_event_handler(*args, **kwargs)
@functools.wraps(old_remove_event_handler)
def remove_proxied_event_handler(self, callback, *args, **kwargs):
args = (self, proxied_event_handlers.get(callback, callback), *args)
return old_remove_event_handler(*args, **kwargs)
TelegramClient.add_event_handler = add_proxied_event_handler
TelegramClient.remove_event_handler = remove_proxied_event_handler
def run_until_disconnected(self):
return _sync_result(loop, self._run_until_disconnected())
TelegramClient.run_until_disconnected = run_until_disconnected
return __asyncthread
def stop():
"""
Stops the fully synchronous code. You
should call this before your script exits.
"""
global __asyncthread
if not __asyncthread:
raise RuntimeError("Can't find asyncio thread")
__asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop)
__asyncthread.executor.shutdown()

View File

@ -1,5 +1,5 @@
"""Various helpers not related to the Telegram API itself"""
import collections
import asyncio
import os
import struct
from hashlib import sha1, sha256
@ -87,4 +87,46 @@ class TotalList(list):
return '[{}, total={}]'.format(
', '.join(repr(x) for x in self), self.total)
class _ReadyQueue:
"""
A queue list that supports an arbitrary cancellation token for `get`.
"""
def __init__(self, loop):
self._list = []
self._loop = loop
self._ready = asyncio.Event(loop=loop)
def append(self, item):
self._list.append(item)
self._ready.set()
def extend(self, items):
self._list.extend(items)
self._ready.set()
async def get(self, cancellation):
"""
Returns a list of all the items added to the queue until now and
clears the list from the queue itself. Returns ``None`` if cancelled.
"""
ready = self._loop.create_task(self._ready.wait())
try:
done, pending = await asyncio.wait(
[ready, cancellation],
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
except asyncio.CancelledError:
done = [cancellation]
if cancellation in done:
ready.cancel()
return None
result = self._list
self._list = []
self._ready.clear()
return result
# endregion

View File

@ -1,74 +0,0 @@
"""
This module holds the abstract `Connection` class.
The `Connection.send` and `Connection.recv` methods need **not** to be
safe across several tasks and may use any amount of ``await`` keywords.
The code using these `Connection`'s should be responsible for using
an ``async with asyncio.Lock:`` block when calling said methods.
Said subclasses need not to worry about reconnecting either, and
should let the errors propagate instead.
"""
import abc
class Connection(abc.ABC):
"""
Represents an abstract connection for Telegram.
Subclasses should implement the actual protocol
being used when encoding/decoding messages.
"""
def __init__(self, *, loop, timeout, proxy=None):
"""
Initializes a new connection.
:param loop: the event loop to be used.
:param timeout: timeout to be used for all operations.
:param proxy: whether to use a proxy or not.
"""
self._loop = loop
self._proxy = proxy
self._timeout = timeout
@abc.abstractmethod
async def connect(self, ip, port):
raise NotImplementedError
@abc.abstractmethod
def get_timeout(self):
"""Returns the timeout used by the connection."""
raise NotImplementedError
@abc.abstractmethod
def is_connected(self):
"""
Determines whether the connection is alive or not.
:return: true if it's connected.
"""
raise NotImplementedError
@abc.abstractmethod
async def close(self):
"""Closes the connection."""
raise NotImplementedError
def clone(self):
"""Creates a copy of this Connection."""
return self.__class__(
loop=self._loop,
proxy=self._proxy,
timeout=self._timeout
)
@abc.abstractmethod
async def recv(self):
"""Receives and unpacks a message"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, message):
"""Encapsulates and sends the given message"""
raise NotImplementedError

View File

@ -0,0 +1,180 @@
import abc
import asyncio
import logging
import socket
import ssl as ssl_mod
__log__ = logging.getLogger(__name__)
class Connection(abc.ABC):
"""
The `Connection` class is a wrapper around ``asyncio.open_connection``.
Subclasses are meant to communicate with this class through a queue.
This class provides a reliable interface that will stay connected
under any conditions for as long as the user doesn't disconnect or
the input parameters to auto-reconnect dictate otherwise.
"""
def __init__(self, ip, port, *, loop, proxy=None):
self._ip = ip
self._port = port
self._loop = loop
self._proxy = proxy
self._reader = None
self._writer = None
self._disconnected = asyncio.Event(loop=loop)
self._disconnected.set()
self._disconnected_future = None
self._send_task = None
self._recv_task = None
self._send_queue = asyncio.Queue(1)
self._recv_queue = asyncio.Queue(1)
async def connect(self, timeout=None, ssl=None):
"""
Establishes a connection with the server.
"""
if not self._proxy:
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(
self._ip, self._port, loop=self._loop, ssl=ssl),
loop=self._loop, timeout=timeout
)
else:
import socks
if ':' in self._ip:
mode, address = socket.AF_INET6, (self._ip, self._port, 0, 0)
else:
mode, address = socket.AF_INET, (self._ip, self._port)
s = socks.socksocket(mode, socket.SOCK_STREAM)
if isinstance(self._proxy, dict):
s.set_proxy(**self._proxy)
else:
s.set_proxy(*self._proxy)
s.setblocking(False)
await asyncio.wait_for(
self._loop.sock_connect(s, address),
timeout=timeout,
loop=self._loop
)
if ssl:
self._socket.settimeout(timeout)
self._socket = ssl_mod.wrap_socket(
s,
do_handshake_on_connect=True,
ssl_version=ssl_mod.PROTOCOL_SSLv23,
ciphers='ADH-AES256-SHA'
)
self._socket.setblocking(False)
self._reader, self._writer = await asyncio.open_connection(
self._ip, self._port, loop=self._loop, sock=s
)
self._disconnected.clear()
self._disconnected_future = None
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop())
def disconnect(self):
"""
Disconnects from the server.
"""
self._disconnected.set()
if self._send_task:
self._send_task.cancel()
if self._recv_task:
self._recv_task.cancel()
if self._writer:
self._writer.close()
@property
def disconnected(self):
if not self._disconnected_future:
self._disconnected_future = \
self._loop.create_task(self._disconnected.wait())
return self._disconnected_future
def clone(self):
"""
Creates a clone of the connection.
"""
return self.__class__(self._ip, self._port, loop=self._loop)
def send(self, data):
"""
Sends a packet of data through this connection mode.
This method returns a coroutine.
"""
return self._send_queue.put(data)
async def recv(self):
"""
Receives a packet of data through this connection mode.
This method returns a coroutine.
"""
ok, result = await self._recv_queue.get()
if ok:
return result
else:
raise result from None
async def _send_loop(self):
"""
This loop is constantly popping items off the queue to send them.
"""
try:
while not self._disconnected.is_set():
self._send(await self._send_queue.get())
await self._writer.drain()
except asyncio.CancelledError:
pass
except Exception:
logging.exception('Unhandled exception in the sending loop')
self.disconnect()
async def _recv_loop(self):
"""
This loop is constantly putting items on the queue as they're read.
"""
try:
while not self._disconnected.is_set():
data = await self._recv()
await self._recv_queue.put((True, data))
except asyncio.CancelledError:
pass
except Exception as e:
await self._recv_queue.put((False, e))
self.disconnect()
@abc.abstractmethod
def _send(self, data):
"""
This method should be implemented differently under each
connection mode and serialize the data into the packet
the way it should be sent through `self._writer`.
"""
raise NotImplementedError
@abc.abstractmethod
async def _recv(self):
"""
This method should be implemented differently under each
connection mode and deserialize the data from the packet
the way it should be read from `self._reader`.
"""
raise NotImplementedError
def __str__(self):
return '{}:{}/{}'.format(
self._ip, self._port,
self.__class__.__name__.replace('Connection', '')
)

View File

@ -1,62 +1,34 @@
import errno
import ssl
import asyncio
from .common import Connection
from ...extensions import TcpClient
from .connection import Connection
SSL_PORT = 443
class ConnectionHttp(Connection):
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy,
ssl=dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
)
self.read = self.conn.read
self.write = self.conn.write
self._host = None
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)
async def connect(self, ip, port):
self._host = '{}:{}'.format(ip, port)
try:
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
def get_timeout(self):
return self.conn.timeout
def is_connected(self):
return self.conn.is_connected
async def close(self):
self.conn.close()
async def recv(self):
while True:
line = await self._read_line()
if line.lower().startswith(b'content-length: '):
await self.read(2)
length = int(line[16:-2])
return await self.read(length)
async def _read_line(self):
newline = ord('\n')
line = await self.read(1)
while line[-1] != newline:
line += await self.read(1)
return line
async def send(self, message):
await self.write(
def _send(self, message):
self._writer.write(
'POST /api HTTP/1.1\r\n'
'Host: {}\r\n'
'Host: {}:{}\r\n'
'Content-Type: application/x-www-form-urlencoded\r\n'
'Connection: keep-alive\r\n'
'Keep-Alive: timeout=100000, max=10000000\r\n'
'Content-Length: {}\r\n\r\n'.format(self._host, len(message))
'Content-Length: {}\r\n\r\n'
.format(self._ip, self._port, len(message))
.encode('ascii') + message
)
async def _recv(self):
while True:
line = await self._reader.readline()
if not line or line[-1] != b'\n':
raise asyncio.IncompleteReadError(line, None)
if line.lower().startswith(b'content-length: '):
await self._reader.readexactly(2)
length = int(line[16:-2])
return await self._reader.readexactly(length)

View File

@ -1,31 +1,44 @@
import struct
from .tcpfull import ConnectionTcpFull
from .connection import Connection
class ConnectionTcpAbridged(ConnectionTcpFull):
class ConnectionTcpAbridged(Connection):
"""
This is the mode with the lowest overhead, as it will
only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common).
"""
async def connect(self, ip, port):
result = await super().connect(ip, port)
await self.conn.write(b'\xef')
return result
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
self._writer.write(b'\xef')
await self._writer.drain()
async def recv(self):
length = struct.unpack('<B', await self.read(1))[0]
if length >= 127:
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
def _write(self, data):
"""
Define wrapper write methods for `TcpObfuscated` to override.
"""
self._writer.write(data)
return await self.read(length << 2)
async def _read(self, n):
"""
Define wrapper read methods for `TcpObfuscated` to override.
"""
return await self._reader.readexactly(n)
async def send(self, message):
length = len(message) >> 2
def _send(self, data):
length = len(data) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
await self.write(length + message)
self._write(length + data)
async def _recv(self):
length = struct.unpack('<B', await self._read(1))[0]
if length >= 127:
length = struct.unpack(
'<i', await self._read(3) + b'\0')[0]
return await self._read(length << 2)

View File

@ -1,10 +1,8 @@
import errno
import struct
from zlib import crc32
from .common import Connection
from .connection import Connection
from ...errors import InvalidChecksumError
from ...extensions import TcpClient
class ConnectionTcpFull(Connection):
@ -12,39 +10,27 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
self._send_counter = 0
self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy
)
self.read = self.conn.read
self.write = self.conn.write
async def connect(self, ip, port):
try:
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
def __init__(self, ip, port, *, loop, proxy=None):
super().__init__(ip, port, loop=loop, proxy=proxy)
self._send_counter = 0
def get_timeout(self):
return self.conn.timeout
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
self._send_counter = 0 # Important or Telegram won't reply
def is_connected(self):
return self.conn.is_connected
def _send(self, data):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(data) + 12
data = struct.pack('<ii', length, self._send_counter) + data
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self._writer.write(data + crc)
async def close(self):
self.conn.close()
async def recv(self):
packet_len_seq = await self.read(8) # 4 and 4
async def _recv(self):
packet_len_seq = await self._reader.readexactly(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
body = await self.read(packet_len - 8)
body = await self._reader.readexactly(packet_len - 8)
checksum = struct.unpack('<I', body[-4:])[0]
body = body[:-4]
@ -53,12 +39,3 @@ class ConnectionTcpFull(Connection):
raise InvalidChecksumError(checksum, valid_checksum)
return body
async def send(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
await self.write(data + crc)

View File

@ -1,20 +1,21 @@
import struct
from .tcpfull import ConnectionTcpFull
from .connection import Connection
class ConnectionTcpIntermediate(ConnectionTcpFull):
class ConnectionTcpIntermediate(Connection):
"""
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length.
"""
async def connect(self, ip, port):
result = await super().connect(ip, port)
await self.conn.write(b'\xee\xee\xee\xee')
return result
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
self._writer.write(b'\xee\xee\xee\xee')
await self._writer.drain()
async def recv(self):
return await self.read(struct.unpack('<i', await self.read(4))[0])
def _send(self, data):
self._writer.write(struct.pack('<i', len(data)) + data)
async def send(self, message):
await self.write(struct.pack('<i', len(message)) + message)
async def _recv(self):
return await self._reader.readexactly(
struct.unpack('<i', await self._reader.readexactly(4))[0])

View File

@ -1,7 +1,6 @@
import os
from .tcpabridged import ConnectionTcpAbridged
from .tcpfull import ConnectionTcpFull
from ...crypto import AESModeCTR
@ -11,16 +10,22 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern.
"""
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
self._aes_encrypt, self._aes_decrypt = None, None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
def __init__(self, ip, port, *, loop, proxy=None):
super().__init__(ip, port, loop=loop, proxy=proxy)
self._aes_encrypt = None
self._aes_decrypt = None
def _write(self, data):
self._writer.write(self._aes_encrypt.encrypt(data))
async def _read(self, n):
return self._aes_decrypt.encrypt(await self._reader.readexactly(n))
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
async def connect(self, ip, port):
result = await ConnectionTcpFull.connect(self, ip, port)
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee')
while True:
random = os.urandom(64)
if (random[0] != b'\xef' and
@ -28,11 +33,11 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
random[4:4] != b'\0\0\0\0'):
break
random = list(random)
random = bytearray(random)
random[56] = random[57] = random[58] = random[59] = 0xef
random_reversed = random[55:7:-1] # Reversed (8, len=48)
# encryption has "continuous buffer" enabled
# Encryption has "continuous buffer" enabled
encrypt_key = bytes(random[8:40])
encrypt_iv = bytes(random[40:56])
decrypt_key = bytes(random_reversed[:32])
@ -42,5 +47,5 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
await self.conn.write(bytes(random))
return result
self._writer.write(random)
await self._writer.drain()

View File

@ -0,0 +1,158 @@
import io
import logging
import struct
from .mtprotostate import MTProtoState
from ..tl import TLRequest
from ..tl.core.tlmessage import TLMessage
from ..tl.core.messagecontainer import MessageContainer
__log__ = logging.getLogger(__name__)
class MTProtoLayer:
"""
This class is the message encryption layer between the methods defined
in the schema and the response objects. It also holds the necessary state
necessary for this encryption to happen.
The `connection` parameter is through which these messages will be sent
and received.
The `auth_key` must be a valid authorization key which will be used to
encrypt these messages. This class is not responsible for generating them.
"""
def __init__(self, connection, auth_key):
self._connection = connection
self._state = MTProtoState(auth_key)
def connect(self, timeout=None):
"""
Wrapper for ``self._connection.connect()``.
"""
return self._connection.connect(timeout=timeout)
def disconnect(self):
"""
Wrapper for ``self._connection.disconnect()``.
"""
self._connection.disconnect()
def reset_state(self):
self._state = MTProtoState(self._state.auth_key)
async def send(self, state_list):
"""
The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Nested lists imply an order is required for the messages in them.
Message containers will be used if there is more than one item.
"""
for data in filter(None, self._pack_state_list(state_list)):
await self._connection.send(self._state.encrypt_message_data(data))
async def recv(self):
"""
Reads a single message from the network, decrypts it and returns it.
"""
body = await self._connection.recv()
return self._state.decrypt_message_data(body)
def _pack_state_list(self, state_list):
"""
The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Packs all their serialized data into a message (possibly
nested inside another message and message container) and
returns the serialized message data.
"""
# Note that the simplest case is writing a single query data into
# a message, and returning the message data and ID. For efficiency
# purposes this method supports more than one message and automatically
# uses containers if deemed necessary.
#
# Technically the message and message container classes could be used
# to store and serialize the data. However, to keep the context local
# and relevant to the only place where such feature is actually used,
# this is not done.
#
# When iterating over the state_list there are two branches, one
# being just a state and the other being a list so the inner states
# depend on each other. In either case, if the packed size exceeds
# the maximum container size, it must be sent. This code is non-
# trivial so it has been factored into an inner function.
#
# A new buffer instance will be used once the size should be "flushed"
buffer = io.BytesIO()
# The batch of requests sent in a single buffer-flush. We need to
# remember which states were written to set their container ID.
batch = []
# The currently written size. Reset when it exceeds the maximum.
size = 0
def write_state(state, after_id=None):
nonlocal buffer, batch, size
if state:
batch.append(state)
size += len(state.data) + TLMessage.SIZE_OVERHEAD
# Flush whenever the current size exceeds the maximum,
# or if there's no state, which indicates force flush.
if not state or size > MessageContainer.MAXIMUM_SIZE:
size -= MessageContainer.MAXIMUM_SIZE
if len(batch) > 1:
# Inlined code to pack several messages into a container
data = struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
) + buffer.getvalue()
buffer = io.BytesIO()
container_id = self._state.write_data_as_message(
buffer, data, content_related=False
)
for s in batch:
s.container_id = container_id
# At this point it's either a single msg or a msg + container
data = buffer.getvalue()
__log__.debug('Packed %d message(s) in %d bytes for sending',
len(batch), len(data))
batch.clear()
buffer = io.BytesIO()
return data
if not state:
return # Just forcibly flushing
# If even after flushing it still exceeds the maximum size,
# this message payload cannot be sent. Telegram would forcibly
# close the connection, and the message would never be confirmed.
if size > MessageContainer.MAXIMUM_SIZE:
state.future.set_exception(
ValueError('Request payload is too big'))
return
# This is the only requirement to make this work.
state.msg_id = self._state.write_data_as_message(
buffer, state.data, isinstance(state.request, TLRequest),
after_id=after_id
)
__log__.debug('Assigned msg_id = %d to %s (%x)',
state.msg_id, state.request.__class__.__name__,
id(state.request))
# TODO Yield in the inner loop -> Telegram "Invalid container". Why?
for state in state_list:
if not isinstance(state, list):
yield write_state(state)
else:
after_id = None
for s in state:
yield write_state(s, after_id)
after_id = s.msg_id
yield write_state(None)
def __str__(self):
return str(self._connection)

View File

@ -30,7 +30,7 @@ class MTProtoPlainSender:
body = bytes(request)
msg_id = self._state._get_new_msg_id()
await self._connection.send(
struct.pack('<QQi', 0, msg_id, len(body)) + body
struct.pack('<qqi', 0, msg_id, len(body)) + body
)
body = await self._connection.recv()

View File

@ -1,13 +1,19 @@
import asyncio
import collections
import logging
from . import MTProtoPlainSender, authenticator
from . import authenticator
from .mtprotolayer import MTProtoLayer
from .mtprotoplainsender import MTProtoPlainSender
from .requeststate import RequestState
from ..tl.tlobject import TLRequest
from .. import utils
from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
rpc_message_to_error
BadMessageError, BrokenAuthKeyError, SecurityError, TypeNotFoundError,
InvalidChecksumError, rpc_message_to_error
)
from ..extensions import BinaryReader
from ..helpers import _ReadyQueue
from ..tl.core import RpcResult, MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest
from ..tl.types import (
@ -20,12 +26,6 @@ from ..utils import AsyncClassWrapper
__log__ = logging.getLogger(__name__)
# Place this object in the send queue when a reconnection is needed
# so there is an item to read and we can early quit the loop, since
# without this it will block until there's something in the queue.
_reconnect_sentinel = object()
class MTProtoSender:
"""
MTProto Mobile Protocol sender
@ -41,16 +41,15 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other
key exists yet.
"""
def __init__(self, state, connection, loop, *,
retries=5, auto_reconnect=True, update_callback=None,
def __init__(self, loop, *,
retries=5, auto_reconnect=True, connect_timeout=None,
update_callback=None,
auth_key_callback=None, auto_reconnect_callback=None):
self.state = state
self._connection = connection
self._connection = None # MTProtoLayer, a.k.a. encrypted connection
self._loop = loop
self._ip = None
self._port = None
self._retries = retries
self._auto_reconnect = auto_reconnect
self._connect_timeout = connect_timeout
self._update_callback = update_callback
self._auth_key_callback = auth_key_callback
self._auto_reconnect_callback = auto_reconnect_callback
@ -69,27 +68,21 @@ class MTProtoSender:
self._send_loop_handle = None
self._recv_loop_handle = None
# Sending something shouldn't block
self._send_queue = _ContainerQueue()
# Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``.
# Note that it may also store lists (implying order must be kept).
self._send_queue = _ReadyQueue(self._loop)
# Telegram responds to messages out of order. Keep
# {id: Message} to set their Future result upon arrival.
self._pending_messages = {}
# Sent states are remembered until a response is received.
self._pending_state = {}
# Containers are accepted or rejected as a whole when any of
# its inner requests are acknowledged. For this purpose we
# all the sent containers here.
self._pending_containers = []
# We need to acknowledge every response from Telegram
# Responses must be acknowledged, and we can also batch these.
self._pending_ack = set()
# Similar to pending_messages but only for the last ack.
# Ack can't be put in the messages because Telegram never
# responds to acknowledges (they're just that, acknowledges),
# so it would grow to infinite otherwise, but on bad salt it's
# necessary to resend them just like everything else.
self._last_ack = None
# Similar to pending_messages but only for the last acknowledges.
# These can't go in pending_messages because no acknowledge for them
# is received, but we may still need to resend their state on bad salts.
self._last_acks = collections.deque(maxlen=10)
# Jump table from response ID to method that handles it
self._handlers = {
@ -111,18 +104,15 @@ class MTProtoSender:
# Public API
async def connect(self, ip, port):
async def connect(self, auth_key, connection):
"""
Connects to the specified ``ip:port``, and generates a new
authorization key for the `MTProtoSender.session` if it does
not exist yet.
Connects to the specified given connection using the given auth key.
"""
if self._user_connected:
__log__.info('User is already connected!')
return
self._ip = ip
self._port = port
self._connection = MTProtoLayer(connection, auth_key)
self._user_connected = True
await self._connect()
@ -134,53 +124,13 @@ class MTProtoSender:
Cleanly disconnects the instance from the network, cancels
all pending requests, and closes the send and receive loops.
"""
if not self._user_connected:
__log__.info('User is already disconnected!')
return
await self._disconnect()
async def _disconnect(self, error=None):
__log__.info('Disconnecting from {}...'.format(self._ip))
self._user_connected = False
try:
__log__.debug('Closing current connection...')
await self._connection.close()
finally:
__log__.debug('Cancelling {} pending message(s)...'
.format(len(self._pending_messages)))
for message in self._pending_messages.values():
if error and not message.future.done():
message.future.set_exception(error)
else:
message.future.cancel()
self._pending_messages.clear()
self._pending_ack.clear()
self._last_ack = None
if self._send_loop_handle:
__log__.debug('Cancelling the send loop...')
self._send_loop_handle.cancel()
if self._recv_loop_handle:
__log__.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
__log__.info('Disconnection from {} complete!'.format(self._ip))
if self._disconnected and not self._disconnected.done():
if error:
self._disconnected.set_exception(error)
else:
self._disconnected.set_result(None)
self._disconnect()
def send(self, request, ordered=False):
"""
This method enqueues the given request to be sent.
The request will be wrapped inside a `TLMessage` until its
response arrives, and the `Future` response of the `TLMessage`
is immediately returned so that one can further ``await`` it:
This method enqueues the given request to be sent. Its send
state will be saved until a response arrives, and a ``Future``
that will be resolved when the response arrives will be returned:
.. code-block:: python
@ -202,23 +152,23 @@ class MTProtoSender:
if not self._user_connected:
raise ConnectionError('Cannot send requests while disconnected')
if utils.is_list_like(request):
result = []
after = None
for r in request:
message = self.state.create_message(
r, loop=self._loop, after=after)
self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message)
result.append(message.future)
after = ordered and message
return result
if not utils.is_list_like(request):
state = RequestState(request, self._loop)
self._send_queue.append(state)
return state.future
else:
message = self.state.create_message(request, loop=self._loop)
self._pending_messages[message.msg_id] = message
self._send_queue.put_nowait(message)
return message.future
states = []
futures = []
for req in request:
state = RequestState(req, self._loop)
states.append(state)
futures.append(state.future)
if ordered:
self._send_queue.append(states)
else:
self._send_queue.extend(states)
return futures
@property
def disconnected(self):
@ -239,12 +189,12 @@ class MTProtoSender:
authorization key if necessary, and starting the send and
receive loops.
"""
__log__.info('Connecting to {}:{}...'.format(self._ip, self._port))
__log__.info('Connecting to %s...', self._connection)
for retry in range(1, self._retries + 1):
try:
__log__.debug('Connection attempt {}...'.format(retry))
await self._connection.connect(self._ip, self._port)
except (asyncio.TimeoutError, OSError) as e:
await self._connection.connect(timeout=self._connect_timeout)
except (OSError, asyncio.TimeoutError) as e:
__log__.warning('Attempt {} at connecting failed: {}: {}'
.format(retry, type(e).__name__, e))
else:
@ -254,16 +204,17 @@ class MTProtoSender:
.format(self._retries))
__log__.debug('Connection success!')
if self.state.auth_key is None:
plain = MTProtoPlainSender(self._connection)
state = self._connection._state
if state.auth_key is None:
plain = MTProtoPlainSender(self._connection._connection)
for retry in range(1, self._retries + 1):
try:
__log__.debug('New auth_key attempt {}...'.format(retry))
self.state.auth_key, self.state.time_offset =\
state.auth_key, state.time_offset =\
await authenticator.do_authentication(plain)
if self._auth_key_callback:
await self._auth_key_callback(self.state.auth_key)
await self._auth_key_callback(state.auth_key)
break
except (SecurityError, AssertionError) as e:
@ -284,14 +235,50 @@ class MTProtoSender:
# First connection or manual reconnection after a failure
if self._disconnected is None or self._disconnected.done():
self._disconnected = self._loop.create_future()
__log__.info('Connection to {} complete!'.format(self._ip))
__log__.info('Connection to %s complete!', self._connection)
def _disconnect(self, error=None):
__log__.info('Disconnecting from %s...', self._connection)
self._user_connected = False
try:
__log__.debug('Closing current connection...')
self._connection.disconnect()
finally:
__log__.debug('Cancelling {} pending message(s)...'
.format(len(self._pending_state)))
for state in self._pending_state.values():
if error and not state.future.done():
state.future.set_exception(error)
else:
state.future.cancel()
self._pending_state.clear()
self._pending_ack.clear()
self._last_ack = None
if self._send_loop_handle:
__log__.debug('Cancelling the send loop...')
self._send_loop_handle.cancel()
if self._recv_loop_handle:
__log__.debug('Cancelling the receive loop...')
self._recv_loop_handle.cancel()
__log__.info('Disconnection from %s complete!', self._connection)
if self._disconnected and not self._disconnected.done():
if error:
self._disconnected.set_exception(error)
else:
self._disconnected.set_result(None)
async def _reconnect(self):
"""
Cleanly disconnects and then reconnects.
"""
self._reconnecting = True
self._send_queue.put_nowait(_reconnect_sentinel)
__log__.debug('Closing current connection...')
self._connection.disconnect()
__log__.debug('Awaiting for the send loop before reconnecting...')
await self._send_loop_handle
@ -300,23 +287,27 @@ class MTProtoSender:
await self._recv_loop_handle
__log__.debug('Closing current connection...')
await self._connection.close()
self._connection.disconnect()
self._reconnecting = False
# Start with a clean state (and thus session ID) to avoid old msgs
self._connection.reset_state()
retries = self._retries if self._auto_reconnect else 0
for retry in range(1, retries + 1):
try:
await self._connect()
for m in self._pending_messages.values():
self._send_queue.put_nowait(m)
except ConnectionError:
__log__.info('Failed reconnection retry %d/%d', retry, retries)
else:
self._send_queue.extend(self._pending_state.values())
self._pending_state.clear()
if self._auto_reconnect_callback:
self._loop.create_task(self._auto_reconnect_callback())
break
except ConnectionError:
__log__.info('Failed reconnection retry %d/%d', retry, retries)
else:
__log__.error('Failed to reconnect automatically.')
await self._disconnect(error=ConnectionError())
@ -326,23 +317,6 @@ class MTProtoSender:
if self._user_connected:
self._loop.create_task(self._reconnect())
def _clean_containers(self, msg_ids):
"""
Helper method to clean containers from the pending messages
once a wrapped msg_id of them has been acknowledged.
This is the only way we can resend TLMessage(MessageContainer)
on bad notifications and also mark them as received once any
of their inner TLMessage is acknowledged.
"""
for i in reversed(range(len(self._pending_containers))):
message = self._pending_containers[i]
for msg in message.obj.messages:
if msg.msg_id in msg_ids:
del self._pending_containers[i]
del self._pending_messages[message.msg_id]
break
# Loops
async def _send_loop(self):
@ -354,67 +328,31 @@ class MTProtoSender:
"""
while self._user_connected and not self._reconnecting:
if self._pending_ack:
self._last_ack = self.state.create_message(
MsgsAck(list(self._pending_ack)), loop=self._loop
)
self._send_queue.put_nowait(self._last_ack)
ack = RequestState(MsgsAck(list(self._pending_ack)), self._loop)
self._send_queue.append(ack)
self._last_acks.append(ack)
self._pending_ack.clear()
messages = await self._send_queue.get()
if messages == _reconnect_sentinel:
if self._reconnecting:
break
state_list = await self._send_queue.get(
self._connection._connection.disconnected)
if state_list is None:
break
try:
await self._connection.send(state_list)
except Exception:
__log__.exception('Unhandled error while sending data')
continue
for state in state_list:
if not isinstance(state, list):
if isinstance(state.request, TLRequest):
self._pending_state[state.msg_id] = state
else:
continue
if isinstance(messages, list):
message = self.state.create_message(
MessageContainer(messages), loop=self._loop)
self._pending_messages[message.msg_id] = message
self._pending_containers.append(message)
else:
message = messages
messages = [message]
__log__.debug(
'Packing %d outgoing message(s) %s...', len(messages),
', '.join(x.obj.__class__.__name__ for x in messages)
)
body = self.state.pack_message(message)
while not any(m.future.cancelled() for m in messages):
try:
__log__.debug('Sending {} bytes...'.format(len(body)))
await self._connection.send(body)
break
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
__log__.info('Connection reset while sending %s', e)
elif isinstance(e, OSError):
__log__.warning('OSError while sending %s', e)
else:
__log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1, loop=self._loop)
self._start_reconnect()
break
else:
# Remove the cancelled messages from pending
__log__.info('Some futures were cancelled, aborted send')
self._clean_containers([m.msg_id for m in messages])
for m in messages:
if m.future.cancelled():
self._pending_messages.pop(m.msg_id, None)
else:
self._send_queue.put_nowait(m)
__log__.debug('Outgoing messages {} sent!'
.format(', '.join(str(m.msg_id) for m in messages)))
for s in state:
if isinstance(s.request, TLRequest):
self._pending_state[s.msg_id] = s
async def _recv_loop(self):
"""
@ -424,68 +362,44 @@ class MTProtoSender:
Besides `connect`, only this method ever receives data.
"""
while self._user_connected and not self._reconnecting:
__log__.debug('Receiving items from the network...')
try:
__log__.debug('Receiving items from the network...')
body = await self._connection.recv()
except asyncio.TimeoutError:
message = await self._connection.recv()
except TypeNotFoundError as e:
__log__.info('Type %08x not found, remaining data %r',
e.invalid_constructor_id, e.remaining)
continue
except asyncio.CancelledError:
return
except Exception as e:
if isinstance(e, ConnectionError):
__log__.info('Connection reset while receiving %s', e)
elif isinstance(e, OSError):
__log__.warning('OSError while receiving %s', e)
else:
__log__.exception('Unhandled exception while receiving')
await asyncio.sleep(1, loop=self._loop)
self._start_reconnect()
break
# TODO Check salt, session_id and sequence_number
__log__.debug('Decoding packet of %d bytes...', len(body))
try:
message = self.state.unpack_message(body)
except (BrokenAuthKeyError, BufferError) as e:
# The authorization key may be broken if a message was
# sent malformed, or if the authkey truly is corrupted.
#
# There may be a buffer error if Telegram's response was too
# short and hence not understood. Reset the authorization key
# and try again in either case.
#
# TODO Is it possible to detect malformed messages vs
# an actually broken authkey?
__log__.warning('Broken authorization key?: {}'.format(e))
self.state.auth_key = None
self._start_reconnect()
break
except SecurityError as e:
# A step while decoding had the incorrect data. This message
# should not be considered safe and it should be ignored.
__log__.warning('Security error while unpacking a '
'received message: {}'.format(e))
continue
except TypeNotFoundError as e:
# The payload inside the message was not a known TLObject.
__log__.info('Server replied with an unknown type {:08x}: {!r}'
.format(e.invalid_constructor_id, e.remaining))
'received message: %s', e)
continue
except InvalidChecksumError as e:
__log__.warning(
'Invalid checksum on the read packet (was %s expected %s)',
e.checksum, e.valid_checksum
)
except asyncio.CancelledError:
return
except Exception as e:
__log__.exception('Unhandled exception while unpacking %s',e)
await asyncio.sleep(1, loop=self._loop)
except (BrokenAuthKeyError, BufferError):
__log__.info('Broken authorization key; resetting')
self._connection._state.auth_key = None
self._start_reconnect()
return
except asyncio.IncompleteReadError:
__log__.info('Telegram closed the connection')
self._start_reconnect()
return
except Exception:
__log__.exception('Unhandled error while receiving data')
self._start_reconnect()
return
else:
try:
await self._process_message(message)
except asyncio.CancelledError:
return
except Exception as e:
__log__.exception('Unhandled exception while '
'processing %s', message)
await asyncio.sleep(1, loop=self._loop)
except Exception:
__log__.exception('Unhandled error while processing msgs')
# Response Handlers
@ -500,6 +414,30 @@ class MTProtoSender:
self._handle_update)
await handler(message)
def _pop_states(self, msg_id):
"""
Pops the states known to match the given ID from pending messages.
This method should be used when the response isn't specific.
"""
state = self._pending_state.pop(msg_id, None)
if state:
return [state]
to_pop = []
for state in self._pending_state.values():
if state.container_id == msg_id:
to_pop.append(state.msg_id)
if to_pop:
return [self._pending_state.pop(x) for x in to_pop]
for ack in self._last_acks:
if ack.msg_id == msg_id:
return [ack]
return []
async def _handle_rpc_result(self, message):
"""
Handles the result for Remote Procedure Calls:
@ -509,11 +447,11 @@ class MTProtoSender:
This is where the future results for sent requests are set.
"""
rpc_result = message.obj
message = self._pending_messages.pop(rpc_result.req_msg_id, None)
state = self._pending_state.pop(rpc_result.req_msg_id, None)
__log__.debug('Handling RPC result for message %d',
rpc_result.req_msg_id)
if not message:
if not state:
# TODO We should not get responses to things we never sent
# However receiving a File() with empty bytes is "common".
# See #658, #759 and #958. They seem to happen in a container
@ -529,22 +467,17 @@ class MTProtoSender:
if rpc_result.error:
error = rpc_message_to_error(rpc_result.error)
self._send_queue.put_nowait(self.state.create_message(
MsgsAck([message.msg_id]), loop=self._loop
))
self._send_queue.append(
RequestState(MsgsAck([state.msg_id]), loop=self._loop))
if not message.future.cancelled():
message.future.set_exception(error)
if not state.future.cancelled():
state.future.set_exception(error)
else:
# TODO Would be nice to avoid accessing a per-obj read_result
# Instead have a variable that indicated how the result should
# be read (an enum) and dispatch to read the result, mostly
# always it's just a normal TLObject.
with BinaryReader(rpc_result.body) as reader:
result = message.obj.read_result(reader)
result = state.request.read_result(reader)
if not message.future.cancelled():
message.future.set_result(result)
if not state.future.cancelled():
state.future.set_result(result)
async def _handle_container(self, message):
"""
@ -582,9 +515,9 @@ class MTProtoSender:
"""
pong = message.obj
__log__.debug('Handling pong for message %d', pong.msg_id)
message = self._pending_messages.pop(pong.msg_id, None)
if message:
message.future.set_result(pong)
state = self._pending_state.pop(pong.msg_id, None)
if state:
state.future.set_result(pong)
async def _handle_bad_server_salt(self, message):
"""
@ -596,18 +529,11 @@ class MTProtoSender:
"""
bad_salt = message.obj
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
self.state.salt = bad_salt.new_server_salt
if self._last_ack and bad_salt.bad_msg_id == self._last_ack.msg_id:
self._send_queue.put_nowait(self._last_ack)
return
self._connection._state.salt = bad_salt.new_server_salt
states = self._pop_states(bad_salt.bad_msg_id)
self._send_queue.extend(states)
try:
self._send_queue.put_nowait(
self._pending_messages[bad_salt.bad_msg_id])
except KeyError:
# May be MsgsAck, those are not saved in pending messages
__log__.info('Message %d not resent due to bad salt',
bad_salt.bad_msg_id)
__log__.debug('%d message(s) will be resent', len(states))
async def _handle_bad_notification(self, message):
"""
@ -618,44 +544,30 @@ class MTProtoSender:
error_code:int = BadMsgNotification;
"""
bad_msg = message.obj
msg = self._pending_messages.get(bad_msg.bad_msg_id)
states = self._pop_states(bad_msg.bad_msg_id)
__log__.debug('Handling bad msg %s', bad_msg)
if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
to = self.state.update_time_offset(correct_msg_id=message.msg_id)
to = self._connection._state.update_time_offset(
correct_msg_id=message.msg_id)
__log__.info('System clock is wrong, set time offset to %ds', to)
# Correct the msg_id *of the message to resend*, not all.
#
# If we correct them all, new "bad message" would not find
# the old invalid IDs, causing all awaits to never finish.
if msg:
del self._pending_messages[msg.msg_id]
self.state.update_message_id(msg)
self._pending_messages[msg.msg_id] = msg
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.state._sequence += 64
self._connection._state._sequence += 64
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.state._sequence -= 16
self._connection._state._sequence -= 16
else:
if msg:
del self._pending_messages[msg.msg_id]
msg.future.set_exception(BadMessageError(bad_msg.error_code))
for state in states:
state.future.set_exception(BadMessageError(bad_msg.error_code))
return
# Messages are to be re-sent once we've corrected the issue
if msg:
self._send_queue.put_nowait(msg)
else:
# May be MsgsAck, those are not saved in pending messages
__log__.info('Message %d not resent due to bad msg',
bad_msg.bad_msg_id)
self._send_queue.extend(states)
__log__.debug('%d messages will be resent due to bad msg', len(states))
async def _handle_detailed_info(self, message):
"""
@ -690,7 +602,7 @@ class MTProtoSender:
"""
# TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created')
self.state.salt = message.obj.server_salt
self._connection._state.salt = message.obj.server_salt
async def _handle_ack(self, message):
"""
@ -709,14 +621,11 @@ class MTProtoSender:
"""
ack = message.obj
__log__.debug('Handling acknowledge for %s', str(ack.msg_ids))
if self._pending_containers:
self._clean_containers(ack.msg_ids)
for msg_id in ack.msg_ids:
msg = self._pending_messages.get(msg_id, None)
if msg and isinstance(msg.obj, LogOutRequest):
del self._pending_messages[msg_id]
msg.future.set_result(True)
state = self._pending_state.get(msg_id)
if state and isinstance(state.request, LogOutRequest):
del self._pending_state[msg_id]
state.future.set_result(True)
async def _handle_future_salts(self, message):
"""
@ -729,52 +638,20 @@ class MTProtoSender:
# TODO save these salts and automatically adjust to the
# correct one whenever the salt in use expires.
__log__.debug('Handling future salts for message %d', message.msg_id)
msg = self._pending_messages.pop(message.msg_id, None)
if msg:
msg.future.set_result(message.obj)
state = self._pending_state.pop(message.msg_id, None)
if state:
state.future.set_result(message.obj)
async def _handle_state_forgotten(self, message):
"""
Handles both :tl:`MsgsStateReq` and :tl:`MsgResendReq` by
enqueuing a :tl:`MsgsStateInfo` to be sent at a later point.
"""
self.send(MsgsStateInfo(req_msg_id=message.msg_id,
info=chr(1) * len(message.obj.msg_ids)))
self._send_queue.append(RequestState(MsgsStateInfo(
req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)),
loop=self._loop))
async def _handle_msg_all(self, message):
"""
Handles :tl:`MsgsAllInfo` by doing nothing (yet).
"""
class _ContainerQueue(asyncio.Queue):
"""
An asyncio queue that's aware of `MessageContainer` instances.
The `get` method returns either a single `TLMessage` or a list
of them that should be turned into a new `MessageContainer`.
Instances of this class can be replaced with the simpler
``asyncio.Queue`` when needed for testing purposes, and
a list won't be returned in said case.
"""
async def get(self):
result = await super().get()
if self.empty() or result == _reconnect_sentinel or\
isinstance(result.obj, MessageContainer):
return result
size = result.size()
result = [result]
while not self.empty():
item = self.get_nowait()
if (item == _reconnect_sentinel or
isinstance(item.obj, MessageContainer)
or size + item.size() > MessageContainer.MAXIMUM_SIZE):
self.put_nowait(item)
break
else:
size += item.size()
result.append(item)
return result

View File

@ -8,7 +8,8 @@ from ..crypto import AES
from ..errors import SecurityError, BrokenAuthKeyError
from ..extensions import BinaryReader
from ..tl.core import TLMessage
from ..tl.tlobject import TLRequest
from ..tl.functions import InvokeAfterMsgRequest
from ..tl.core.gzippacked import GzipPacked
__log__ = logging.getLogger(__name__)
@ -27,6 +28,14 @@ class MTProtoState:
for all these is not a good idea as each need their own authkey, and
the concept of "copying" sessions with the unnecessary entities or
updates state for these connections doesn't make sense.
While it would be possible to have a `MTProtoPlainState` that does no
encryption so that it was usable through the `MTProtoLayer` and thus
avoid the need for a `MTProtoPlainSender`, the `MTProtoLayer` is more
focused to efficiency and this state is also more advanced (since it
supports gzipping and invoking after other message IDs). There are too
many methods that would be needed to make it convenient to use for the
authentication process, at which point the `MTProtoPlainSender` is better.
"""
def __init__(self, auth_key):
# Session IDs can be random on every connection
@ -37,20 +46,6 @@ class MTProtoState:
self._sequence = 0
self._last_msg_id = 0
def create_message(self, obj, *, loop, after=None):
"""
Creates a new `telethon.tl.tl_message.TLMessage` from
the given `telethon.tl.tlobject.TLObject` instance.
"""
return TLMessage(
msg_id=self._get_new_msg_id(),
seq_no=self._get_seq_no(isinstance(obj, TLRequest)),
obj=obj,
after_id=after.msg_id if after else None,
out=True, # Pre-convert the request into bytes
loop=loop
)
def update_message_id(self, message):
"""
Updates the message ID to a new one,
@ -74,14 +69,31 @@ class MTProtoState:
return aes_key, aes_iv
def pack_message(self, message):
def write_data_as_message(self, buffer, data, content_related,
*, after_id=None):
"""
Packs the given `telethon.tl.tl_message.TLMessage` using the
current authorization key following MTProto 2.0 guidelines.
Writes a message containing the given data into buffer.
See https://core.telegram.org/mtproto/description.
Returns the message id.
"""
data = struct.pack('<qq', self.salt, self.id) + bytes(message)
msg_id = self._get_new_msg_id()
seq_no = self._get_seq_no(content_related)
if after_id is None:
body = GzipPacked.gzip_if_smaller(content_related, data)
else:
body = GzipPacked.gzip_if_smaller(content_related,
bytes(InvokeAfterMsgRequest(after_id, data)))
buffer.write(struct.pack('<qii', msg_id, seq_no, len(body)))
buffer.write(body)
return msg_id
def encrypt_message_data(self, data):
"""
Encrypts the given message data using the current authorization key
following MTProto 2.0 guidelines core.telegram.org/mtproto/description.
"""
data = struct.pack('<qq', self.salt, self.id) + data
padding = os.urandom(-(len(data) + 12) % 16 + 12)
# Being substr(what, offset, length); x = 0 for client
@ -97,16 +109,18 @@ class MTProtoState:
return (key_id + msg_key +
AES.encrypt_ige(data + padding, aes_key, aes_iv))
def unpack_message(self, body):
def decrypt_message_data(self, body):
"""
Inverse of `pack_message` for incoming server messages.
Inverse of `encrypt_message_data` for incoming server messages.
"""
if len(body) < 8:
# TODO If len == 4, raise HTTPErrorCode(-little endian int)
if body == b'l\xfe\xff\xff':
raise BrokenAuthKeyError()
else:
raise BufferError("Can't decode packet ({})".format(body))
# TODO Check salt, session_id and sequence_number
key_id = struct.unpack('<Q', body[:8])[0]
if key_id != self.auth_key.key_id:
raise SecurityError('Server replied with an invalid auth key')
@ -136,7 +150,7 @@ class MTProtoState:
# reader isn't used for anything else after this, it's unnecessary.
obj = reader.tgread_object()
return TLMessage(remote_msg_id, remote_sequence, obj, loop=None)
return TLMessage(remote_msg_id, remote_sequence, obj)
def _get_new_msg_id(self):
"""

View File

@ -0,0 +1,18 @@
import asyncio
class RequestState:
"""
This request state holds several information relevant to sent messages,
in particular the message ID assigned to the request, the container ID
it belongs to, the request itself, the request as bytes, and the future
result that will eventually be resolved.
"""
__slots__ = ('container_id', 'msg_id', 'request', 'data', 'future')
def __init__(self, request, loop):
self.container_id = None
self.msg_id = None
self.request = request
self.data = bytes(request)
self.future = asyncio.Future(loop=loop)

View File

@ -11,15 +11,14 @@ class GzipPacked(TLObject):
self.data = data
@staticmethod
def gzip_if_smaller(request):
def gzip_if_smaller(content_related, data):
"""Calls bytes(request), and based on a certain threshold,
optionally gzips the resulting data. If the gzipped data is
smaller than the original byte array, this is returned instead.
Note that this only applies to content related requests.
"""
data = bytes(request)
if isinstance(request, TLRequest) and len(data) > 512:
if content_related and len(data) > 512:
gzipped = bytes(GzipPacked(data))
return gzipped if len(gzipped) < len(data) else data
else:

View File

@ -27,11 +27,6 @@ class MessageContainer(TLObject):
],
}
def __bytes__(self):
return struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
) + b''.join(bytes(m) for m in self.messages)
@classmethod
def from_reader(cls, reader):
# This assumes that .read_* calls are done in the order they appear
@ -43,5 +38,5 @@ class MessageContainer(TLObject):
before = reader.tell_position()
obj = reader.tgread_object() # May over-read e.g. RpcResult
reader.set_position(before + length)
messages.append(TLMessage(msg_id, seq_no, obj, loop=None))
messages.append(TLMessage(msg_id, seq_no, obj))
return MessageContainer(messages)

View File

@ -1,10 +1,6 @@
import asyncio
import logging
import struct
from .gzippacked import GzipPacked
from .. import TLObject
from ..functions import InvokeAfterMsgRequest
__log__ = logging.getLogger(__name__)
@ -17,83 +13,26 @@ class TLMessage(TLObject):
message msg_id:long seqno:int bytes:int body:bytes = Message;
Each message has its own unique identifier, and the body is simply
the serialized request that should be executed on the server. Then
Telegram will, at some point, respond with the result for this msg.
the serialized request that should be executed on the server, or
the response object from Telegram. Since the body is always a valid
object, it makes sense to store the object and not the bytes to
ease working with them.
Thus it makes sense that requests and their result are bound to a
sent `TLMessage`, and this result can be represented as a `Future`
that will eventually be set with either a result, error or cancelled.
There is no need to add serializing logic here since that can be
inlined and is unlikely to change. Thus these are only needed to
encapsulate responses.
"""
def __init__(self, msg_id, seq_no, obj, *, loop, out=False, after_id=0):
SIZE_OVERHEAD = 12
def __init__(self, msg_id, seq_no, obj):
self.msg_id = msg_id
self.seq_no = seq_no
self.obj = obj
self.container_msg_id = None
# If no loop is given then it is an incoming message.
# Only outgoing messages need the future to await them.
self.future = loop.create_future() if loop else None
# After which message ID this one should run. We do this so
# InvokeAfterMsgRequest is transparent to the user and we can
# easily invoke after while confirming the original request.
# TODO Currently we don't update this if another message ID changes
self.after_id = after_id
# There are two use-cases for the TLMessage, outgoing and incoming.
# Outgoing messages are meant to be serialized and sent across the
# network so it makes sense to pack them as early as possible and
# avoid this computation if it needs to be resent, and also shows
# serializing-errors as early as possible (foreground task).
#
# We assume obj won't change so caching the bytes is safe to do.
# Caching bytes lets us get the size in a fast way, necessary for
# knowing whether a container can be sent (<1MB) or not (too big).
#
# Incoming messages don't really need this body, but we save the
# msg_id and seq_no inside the body for consistency and raise if
# one tries to bytes()-ify the entire message (len == 12).
if not out:
self._body = struct.pack('<qi', msg_id, seq_no)
else:
try:
if self.after_id is None:
body = GzipPacked.gzip_if_smaller(self.obj)
else:
body = GzipPacked.gzip_if_smaller(
InvokeAfterMsgRequest(self.after_id, self.obj))
except Exception:
# struct.pack doesn't give a lot of information about
# why it may fail so log the exception AND the object
__log__.exception('Failed to pack %s', self.obj)
raise
self._body = struct.pack('<qii', msg_id, seq_no, len(body)) + body
def to_dict(self):
return {
'_': 'TLMessage',
'msg_id': self.msg_id,
'seq_no': self.seq_no,
'obj': self.obj,
'container_msg_id': self.container_msg_id
'obj': self.obj
}
@property
def msg_id(self):
return struct.unpack('<q', self._body[:8])[0]
@msg_id.setter
def msg_id(self, value):
self._body = struct.pack('<q', value) + self._body[8:]
@property
def seq_no(self):
return struct.unpack('<i', self._body[8:12])[0]
def __bytes__(self):
if len(self._body) == 12: # msg_id, seqno
raise TypeError('Incoming messages should not be bytes()-ed')
return self._body
def size(self):
return len(self._body)

View File

@ -260,10 +260,7 @@ class Conversation(ChatGetter):
if isinstance(event, type):
event = event()
# Since we await resolve here we don't need to await resolved.
# We know it has already been resolved, unlike when normally
# adding an event handler, for which a task is created to resolve.
await event.resolve()
await event.resolve(self._client)
counter = Conversation._custom_counter
Conversation._custom_counter += 1

View File

@ -1,5 +1,5 @@
from . import Draft
from .. import TLObject, types
from .. import TLObject, types, functions
from ... import utils
@ -96,6 +96,17 @@ class Dialog:
return await self._client.send_message(
self.input_entity, *args, **kwargs)
async def delete(self):
if self.is_channel:
await self._client(functions.channels.LeaveChannelRequest(
self.input_entity))
else:
if self.is_group:
await self._client(functions.messages.DeleteChatUserRequest(
self.entity.id, types.InputPeerSelf()))
await self._client(functions.messages.DeleteHistoryRequest(
self.input_entity, 0))
def to_dict(self):
return {
'_': 'Dialog',

View File

@ -3,9 +3,10 @@ import datetime
from .. import TLObject
from ..functions.messages import SaveDraftRequest
from ..types import UpdateDraftMessage, DraftMessage
from ... import default
from ...errors import RPCError
from ...extensions import markdown
from ...utils import Default, get_peer_id, get_input_peer
from ...utils import get_peer_id, get_input_peer
class Draft:
@ -116,7 +117,7 @@ class Draft:
return not self._text
async def set_message(
self, text=None, reply_to=0, parse_mode=Default,
self, text=None, reply_to=0, parse_mode=default,
link_preview=None):
"""
Changes the draft message on the Telegram servers. The changes are
@ -163,7 +164,7 @@ class Draft:
return result
async def send(self, clear=True, parse_mode=Default):
async def send(self, clear=True, parse_mode=default):
"""
Sends the contents of this draft to the dialog. This is just a
wrapper around ``send_message(dialog.input_entity, *args, **kwargs)``.

View File

@ -1,7 +1,7 @@
import hashlib
from .. import functions, types
from ... import utils
from ... import default, utils
class InlineBuilder:
@ -55,7 +55,7 @@ class InlineBuilder:
async def article(
self, title, description=None,
*, url=None, thumb=None, content=None,
id=None, text=None, parse_mode=utils.Default, link_preview=True,
id=None, text=None, parse_mode=default, link_preview=True,
geo=None, period=60, contact=None, game=False, buttons=None
):
"""
@ -105,7 +105,7 @@ class InlineBuilder:
async def photo(
self, file, *, id=None,
text=None, parse_mode=utils.Default, link_preview=True,
text=None, parse_mode=default, link_preview=True,
geo=None, period=60, contact=None, game=False, buttons=None
):
"""
@ -144,7 +144,7 @@ class InlineBuilder:
self, file, title=None, *, description=None, type=None,
mime_type=None, attributes=None, force_document=False,
voice_note=False, video_note=False, use_cache=True, id=None,
text=None, parse_mode=utils.Default, link_preview=True,
text=None, parse_mode=default, link_preview=True,
geo=None, period=60, contact=None, game=False, buttons=None
):
"""
@ -219,7 +219,7 @@ class InlineBuilder:
async def game(
self, short_name, *, id=None,
text=None, parse_mode=utils.Default, link_preview=True,
text=None, parse_mode=default, link_preview=True,
geo=None, period=60, contact=None, game=False, buttons=None
):
"""
@ -247,7 +247,7 @@ class InlineBuilder:
async def _message(
self, *,
text=None, parse_mode=utils.Default, link_preview=True,
text=None, parse_mode=default, link_preview=True,
geo=None, period=60, contact=None, game=False, buttons=None
):
if sum(1 for x in (text, geo, contact, game) if x) != 1:

View File

@ -46,14 +46,6 @@ VALID_USERNAME_RE = re.compile(
)
class Default:
"""
Sentinel value to indicate that the default value should be used.
Currently used for the ``parse_mode``, where a ``None`` mode should
be considered different from using the default.
"""
def chunks(iterable, size=100):
"""
Turns the given iterable into chunks of the specified size,
@ -271,9 +263,41 @@ def get_input_photo(photo):
if isinstance(photo, types.PhotoEmpty):
return types.InputPhotoEmpty()
if isinstance(photo, types.messages.ChatFull):
photo = photo.full_chat
if isinstance(photo, types.ChannelFull):
return get_input_photo(photo.chat_photo)
elif isinstance(photo, types.UserFull):
return get_input_photo(photo.profile_photo)
elif isinstance(photo, (types.Channel, types.Chat, types.User)):
return get_input_photo(photo.photo)
if isinstance(photo, (types.UserEmpty, types.ChatEmpty,
types.ChatForbidden, types.ChannelForbidden)):
return types.InputPhotoEmpty()
_raise_cast_fail(photo, 'InputPhoto')
def get_input_chat_photo(photo):
"""Similar to :meth:`get_input_peer`, but for chat photos"""
try:
if photo.SUBCLASS_OF_ID == 0xd4eb2d74: # crc32(b'InputChatPhoto')
return photo
elif photo.SUBCLASS_OF_ID == 0xe7655f1f: # crc32(b'InputFile'):
return types.InputChatUploadedPhoto(photo)
except AttributeError:
_raise_cast_fail(photo, 'InputChatPhoto')
photo = get_input_photo(photo)
if isinstance(photo, types.InputPhoto):
return types.InputChatPhoto(photo)
elif isinstance(photo, types.InputPhotoEmpty):
return types.InputChatPhotoEmpty()
_raise_cast_fail(photo, 'InputChatPhoto')
def get_input_geo(geo):
"""Similar to :meth:`get_input_peer`, but for geo points"""
try:

View File

@ -25,7 +25,9 @@ AUTO_CASTS = {
'InputNotifyPeer': 'await client._get_input_notify({})',
'InputMedia': 'utils.get_input_media({})',
'InputPhoto': 'utils.get_input_photo({})',
'InputMessage': 'utils.get_input_message({})'
'InputMessage': 'utils.get_input_message({})',
'InputDocument': 'utils.get_input_document({})',
'InputChatPhoto': 'utils.get_input_chat_photo({})',
}
NAMED_AUTO_CASTS = {

View File

@ -1,5 +0,0 @@
from .crypto_test import CryptoTests
from .network_test import NetworkTests
from .parser_test import ParserTests
from .tl_test import TLTests
from .utils_test import UtilsTests

View File

@ -1,143 +0,0 @@
import unittest
from hashlib import sha1
import telethon.helpers as utils
from telethon.crypto import AES, Factorization
# from crypto.PublicKey import RSA as PyCryptoRSA
class CryptoTests(unittest.TestCase):
def setUp(self):
# Test known values
self.key = b'\xd1\xf4MXy\x0c\xf8/z,\xe9\xf9\xa4\x17\x04\xd9C\xc9\xaba\x81\xf3\xf8\xdd\xcb\x0c6\x92\x01\x1f\xc2y'
self.iv = b':\x02\x91x\x90Dj\xa6\x03\x90C\x08\x9e@X\xb5E\xffwy\xf3\x1c\xde\xde\xfbo\x8dm\xd6e.Z'
self.plain_text = b'Non encrypted text :D'
self.plain_text_padded = b'My len is more uniform, promise!'
self.cipher_text = b'\xb6\xa7\xec.\xb9\x9bG\xcb\xe9{\x91[\x12\xfc\x84D\x1c' \
b'\x93\xd9\x17\x03\xcd\xd6\xb1D?\x98\xd2\xb5\xa5U\xfd'
self.cipher_text_padded = b"W\xd1\xed'\x01\xa6c\xc3\xcb\xef\xaa\xe5\x1d\x1a" \
b"[\x1b\xdf\xcdI\x1f>Z\n\t\xb9\xd2=\xbaF\xd1\x8e'"
def test_sha1(self):
string = 'Example string'
hash_sum = sha1(string.encode('utf-8')).digest()
expected = b'\nT\x92|\x8d\x06:)\x99\x04\x8e\xf8j?\xc4\x8e\xd3}m9'
self.assertEqual(hash_sum, expected,
msg='Invalid sha1 hash_sum representation (should be {}, but is {})'
.format(expected, hash_sum))
@unittest.skip("test_aes_encrypt needs fix")
def test_aes_encrypt(self):
value = AES.encrypt_ige(self.plain_text, self.key, self.iv)
take = 16 # Don't take all the bytes, since latest involve are random padding
self.assertEqual(value[:take], self.cipher_text[:take],
msg='Ciphered text ("{}") does not equal expected ("{}")'
.format(value[:take], self.cipher_text[:take]))
value = AES.encrypt_ige(self.plain_text_padded, self.key, self.iv)
self.assertEqual(value, self.cipher_text_padded,
msg='Ciphered text ("{}") does not equal expected ("{}")'
.format(value, self.cipher_text_padded))
def test_aes_decrypt(self):
# The ciphered text must always be padded
value = AES.decrypt_ige(self.cipher_text_padded, self.key, self.iv)
self.assertEqual(value, self.plain_text_padded,
msg='Decrypted text ("{}") does not equal expected ("{}")'
.format(value, self.plain_text_padded))
@unittest.skip("test_calc_key needs fix")
def test_calc_key(self):
# TODO Upgrade test for MtProto 2.0
shared_key = b'\xbc\xd2m\xb7\xcav\xf4][\x88\x83\' \xf3\x11\x8as\xd04\x941\xae' \
b'*O\x03\x86\x9a/H#\x1a\x8c\xb5j\xe9$\xe0IvCm^\xe70\x1a5C\t\x16' \
b'\x03\xd2\x9d\xa9\x89\xd6\xce\x08P\x0fdr\xa0\xb3\xeb\xfecv\x1a' \
b'\xdfJ\x14\x96\x98\x16\xa3G\xab\x04\x14!\\\xeb\n\xbcn\xdf\xc4%' \
b'\xc6\t\xb7\x16\x14\x9c\'\x81\x15=\xb0\xaf\x0e\x0bR\xaa\x0466s' \
b'\xf0\xcf\xb7\xb8>,D\x94x\xd7\xf8\xe0\x84\xcb%\xd3\x05\xb2\xe8' \
b'\x95Mr?\xa2\xe8In\xf9\x0b[E\x9b\xaa\x0cX\x7f\x0ei\xde\xeed\x1d' \
b'x/J\xce\xea^}0;\xa83B\xbbR\xa1\xbfe\x04\xb9\x1e\xa1"f=\xa5M@' \
b'\x9e\xdd\x81\x80\xc9\xa5\xfb\xfcg\xdd\x15\x03p!\x0ffD\x16\x892' \
b'\xea\xca\xb1A\x99O\xa94P\xa9\xa2\xc6;\xb2C9\x1dC5\xd2\r\xecL' \
b'\xd9\xabw-\x03\ry\xc2v\x17]\x02\x15\x0cBa\x97\xce\xa5\xb1\xe4]' \
b'\x8e\xe0,\xcfC{o\xfa\x99f\xa4pM\x00'
# Calculate key being the client
msg_key = b'\xba\x1a\xcf\xda\xa8^Cbl\xfa\xb6\x0c:\x9b\xb0\xfc'
key, iv = utils.calc_key(shared_key, msg_key, client=True)
expected_key = b"\xaf\xe3\x84Qm\xe0!\x0c\xd91\xe4\x9a\xa0v_gc" \
b"x\xa1\xb0\xc9\xbc\x16'v\xcf,\x9dM\xae\xc6\xa5"
expected_iv = b'\xb8Q\xf3\xc5\xa3]\xc6\xdf\x9e\xe0Q\xbd"\x8d' \
b'\x13\t\x0e\x9a\x9d^8\xa2\xf8\xe7\x00w\xd9\xc1' \
b'\xa7\xa0\xf7\x0f'
self.assertEqual(key, expected_key,
msg='Invalid key (expected ("{}"), got ("{}"))'
.format(expected_key, key))
self.assertEqual(iv, expected_iv,
msg='Invalid IV (expected ("{}"), got ("{}"))'
.format(expected_iv, iv))
# Calculate key being the server
msg_key = b'\x86m\x92i\xcf\x8b\x93\xaa\x86K\x1fi\xd04\x83]'
key, iv = utils.calc_key(shared_key, msg_key, client=False)
expected_key = b'\xdd0X\xb6\x93\x8e\xc9y\xef\x83\xf8\x8cj' \
b'\xa7h\x03\xe2\xc6\xb16\xc5\xbb\xfc\xe7' \
b'\xdf\xd6\xb1g\xf7u\xcfk'
expected_iv = b'\xdcL\xc2\x18\x01J"X\x86lb\xb6\xb547\xfd' \
b'\xe2a4\xb6\xaf}FS\xd7[\xe0N\r\x19\xfb\xbc'
self.assertEqual(key, expected_key,
msg='Invalid key (expected ("{}"), got ("{}"))'
.format(expected_key, key))
self.assertEqual(iv, expected_iv,
msg='Invalid IV (expected ("{}"), got ("{}"))'
.format(expected_iv, iv))
def test_generate_key_data_from_nonce(self):
server_nonce = int.from_bytes(b'The 16-bit nonce', byteorder='little')
new_nonce = int.from_bytes(b'The new, calculated 32-bit nonce', byteorder='little')
key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce)
expected_key = b'/\xaa\x7f\xa1\xfcs\xef\xa0\x99zh\x03M\xa4\x8e\xb4\xab\x0eE]b\x95|\xfe\xc0\xf8\x1f\xd4\xa0\xd4\xec\x91'
expected_iv = b'\xf7\xae\xe3\xc8+=\xc2\xb8\xd1\xe1\x1b\x0e\x10\x07\x9fn\x9e\xdc\x960\x05\xf9\xea\xee\x8b\xa1h The '
self.assertEqual(key, expected_key,
msg='Key ("{}") does not equal expected ("{}")'
.format(key, expected_key))
self.assertEqual(iv, expected_iv,
msg='IV ("{}") does not equal expected ("{}")'
.format(iv, expected_iv))
# test_fringerprint_from_key can't be skipped due to ImportError
# def test_fingerprint_from_key(self):
# assert rsa._compute_fingerprint(PyCryptoRSA.importKey(
# '-----BEGIN RSA PUBLIC KEY-----\n'
# 'MIIBCgKCAQEAwVACPi9w23mF3tBkdZz+zwrzKOaaQdr01vAbU4E1pvkfj4sqDsm6\n'
# 'lyDONS789sVoD/xCS9Y0hkkC3gtL1tSfTlgCMOOul9lcixlEKzwKENj1Yz/s7daS\n'
# 'an9tqw3bfUV/nqgbhGX81v/+7RFAEd+RwFnK7a+XYl9sluzHRyVVaTTveB2GazTw\n'
# 'Efzk2DWgkBluml8OREmvfraX3bkHZJTKX4EQSjBbbdJ2ZXIsRrYOXfaA+xayEGB+\n'
# '8hdlLmAjbCVfaigxX0CDqWeR1yFL9kwd9P0NsZRPsmoqVwMbMu7mStFai6aIhc3n\n'
# 'Slv8kg9qv1m6XHVQY3PnEw+QQtqSIXklHwIDAQAB\n'
# '-----END RSA PUBLIC KEY-----'
# )) == b'!k\xe8l\x02+\xb4\xc3', 'Wrong fingerprint calculated'
def test_factorize(self):
pq = 3118979781119966969
p, q = Factorization.factorize(pq)
if p > q:
p, q = q, p
self.assertEqual(p, 1719614201,
msg='Factorized pair did not yield the correct result')
self.assertEqual(q, 1813767169,
msg='Factorized pair did not yield the correct result')

View File

@ -1,49 +0,0 @@
import unittest
import os
from io import BytesIO
from random import randint
from hashlib import sha256
from telethon import TelegramClient
# Fill in your api_id and api_hash when running the tests
# and REMOVE THEM once you've finished testing them.
api_id = None
api_hash = None
class HigherLevelTests(unittest.TestCase):
def setUp(self):
if not api_id or not api_hash:
raise ValueError('Please fill in both your api_id and api_hash.')
@unittest.skip("you can't seriously trash random mobile numbers like that :)")
def test_cdn_download(self):
client = TelegramClient(None, api_id, api_hash)
client.session.set_dc(0, '149.154.167.40', 80)
self.assertTrue(client.connect())
try:
phone = '+999662' + str(randint(0, 9999)).zfill(4)
client.send_code_request(phone)
client.sign_up('22222', 'Test', 'DC')
me = client.get_me()
data = os.urandom(2 ** 17)
client.send_file(
me, data,
progress_callback=lambda c, t:
print('test_cdn_download:uploading {:.2%}...'.format(c/t))
)
msg = client.get_messages(me)[1][0]
out = BytesIO()
client.download_media(msg, out)
self.assertEqual(sha256(data).digest(), sha256(out.getvalue()).digest())
out = BytesIO()
client.download_media(msg, out) # Won't redirect
self.assertEqual(sha256(data).digest(), sha256(out.getvalue()).digest())
client.log_out()
finally:
client.disconnect()

View File

@ -1,44 +0,0 @@
import random
import socket
import threading
import unittest
import telethon.network.authenticator as authenticator
from telethon.extensions import TcpClient
from telethon.network import Connection
def run_server_echo_thread(port):
def server_thread():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', port))
s.listen(1)
connection, address = s.accept()
with connection:
data = connection.recv(16)
connection.send(data)
server = threading.Thread(target=server_thread)
server.start()
class NetworkTests(unittest.TestCase):
@unittest.skip("test_tcp_client needs fix")
def test_tcp_client(self):
port = random.randint(50000, 60000) # Arbitrary non-privileged port
run_server_echo_thread(port)
msg = b'Unit testing...'
client = TcpClient()
client.connect('localhost', port)
client.write(msg)
self.assertEqual(msg, client.read(15),
msg='Read message does not equal sent message')
client.close()
@unittest.skip("Some parameters changed, so IP doesn't go there anymore.")
def test_authenticator(self):
transport = Connection('149.154.167.91', 443)
self.assertTrue(authenticator.do_authentication(transport))
transport.close()

View File

@ -1,8 +0,0 @@
import unittest
class ParserTests(unittest.TestCase):
"""There are no tests yet"""
@unittest.skip("there should be parser tests")
def test_parser(self):
self.assertTrue(True)

View File

@ -1,8 +0,0 @@
import unittest
class TLTests(unittest.TestCase):
"""There are no tests yet"""
@unittest.skip("there should be TL tests")
def test_tl(self):
self.assertTrue(True)

View File

@ -1,66 +0,0 @@
import os
import unittest
from telethon.tl import TLObject
from telethon.extensions import BinaryReader
class UtilsTests(unittest.TestCase):
def test_binary_writer_reader(self):
# Test that we can read properly
data = b'\x01\x05\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
b'\x88A\x00\x00\x00\x00\x00\x009@\x1a\x1b\x1c\x1d\x1e\x1f ' \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
b'\x00\x80'
with BinaryReader(data) as reader:
value = reader.read_byte()
self.assertEqual(value, 1,
msg='Example byte should be 1 but is {}'.format(value))
value = reader.read_int()
self.assertEqual(value, 5,
msg='Example integer should be 5 but is {}'.format(value))
value = reader.read_long()
self.assertEqual(value, 13,
msg='Example long integer should be 13 but is {}'.format(value))
value = reader.read_float()
self.assertEqual(value, 17.0,
msg='Example float should be 17.0 but is {}'.format(value))
value = reader.read_double()
self.assertEqual(value, 25.0,
msg='Example double should be 25.0 but is {}'.format(value))
value = reader.read(7)
self.assertEqual(value, bytes([26, 27, 28, 29, 30, 31, 32]),
msg='Example bytes should be {} but is {}'
.format(bytes([26, 27, 28, 29, 30, 31, 32]), value))
value = reader.read_large_int(128, signed=False)
self.assertEqual(value, 2**127,
msg='Example large integer should be {} but is {}'.format(2**127, value))
def test_binary_tgwriter_tgreader(self):
small_data = os.urandom(33)
small_data_padded = os.urandom(19) # +1 byte for length = 20 (%4 = 0)
large_data = os.urandom(999)
large_data_padded = os.urandom(1024)
data = (small_data, small_data_padded, large_data, large_data_padded)
string = 'Testing Telegram strings, this should work properly!'
serialized = b''.join(TLObject.serialize_bytes(d) for d in data) + \
TLObject.serialize_bytes(string)
with BinaryReader(serialized) as reader:
# And then try reading it without errors (it should be unharmed!)
for datum in data:
value = reader.tgread_bytes()
self.assertEqual(value, datum,
msg='Example bytes should be {} but is {}'.format(datum, value))
value = reader.tgread_string()
self.assertEqual(value, string,
msg='Example string should be {} but is {}'.format(string, value))