Merge branch 'master' of github.com:lonamiwebs/Telethon

This commit is contained in:
Tanuj 2017-10-01 11:52:55 +01:00
commit a6e295da65
31 changed files with 1674 additions and 1299 deletions

View File

@ -1,26 +1,9 @@
<!--
Please remember that issues here should be related to the library itself and NOT your code.
0. The library is Python 3.x, not Python 2.x.
1. If you're posting an issue, make sure it's a bug in the library, not in your code.
2. If you're posting a question, make sure you have read and tried enough things first.
3. Show as much information as possible, including your failed attempts, and the full console output (to include the whole traceback with line numbers).
4. Good looking issues are a lot more appealing. If you need help check out https://guides.github.com/features/mastering-markdown/.
Python 2 is NOT supported. Make sure you're using the latest version of Telethon before reporting:
pip install telethon --upgrade
Some questions are okay, but make sure you've investigated enough on your own before or you will end up on the Wall of Shame:
https://github.com/LonamiWebs/Telethon/wiki/Wall-of-Shame.
You may also want to watch "How (not) to ask a technical question" over https://youtu.be/53zkBvL4ZB4
-->
### What went wrong
Describe what happened or what the error you have is.
```
// paste the crash log here if any
```
### What I've done
Either a code example of what you were trying to do, or steps to reproduce, or methods you have tried invoking.
```python
# Add your Python code here
```
### More information
If you think other information can be relevant (e.g. operating system or other variables), add it here.

View File

@ -199,14 +199,27 @@ def generate_index(folder, original_paths):
def get_description(arg):
"""Generates a proper description for the given argument"""
desc = []
otherwise = False
if arg.can_be_inferred:
desc.append('If left to None, it will be inferred automatically.')
if arg.is_vector:
desc.append('A list must be supplied for this argument.')
if arg.is_generic:
desc.append('A different Request must be supplied for this argument.')
if arg.is_flag:
desc.append('If left unspecified, it will be inferred automatically.')
otherwise = True
elif arg.is_flag:
desc.append('This argument can be omitted.')
otherwise = True
if arg.is_vector:
if arg.is_generic:
desc.append('A list of other Requests must be supplied.')
else:
desc.append('A list must be supplied.')
elif arg.is_generic:
desc.append('A different Request must be supplied for this argument.')
else:
otherwise = False # Always reset to false if no other text is added
if otherwise:
desc.insert(1, 'Otherwise,')
desc[-1] = desc[-1][:1].lower() + desc[-1][1:]
return ' '.join(desc)
@ -218,6 +231,7 @@ def generate_documentation(scheme_file):
original_paths = {
'css': 'css/docs.css',
'arrow': 'img/arrow.svg',
'404': '404.html',
'index_all': 'index.html',
'index_types': 'types/index.html',
'index_methods': 'methods/index.html',
@ -360,7 +374,8 @@ def generate_documentation(scheme_file):
for tltype, constructors in tltypes.items():
filename = get_path_for_type(tltype)
out_dir = os.path.dirname(filename)
os.makedirs(out_dir, exist_ok=True)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
# Since we don't have access to the full TLObject, split the type
if '.' in tltype:
@ -503,15 +518,26 @@ def generate_documentation(scheme_file):
methods = sorted(methods, key=lambda m: m.name)
constructors = sorted(constructors, key=lambda c: c.name)
def fmt(xs):
ys = {x: get_class_name(x) for x in xs} # cache TLObject: display
zs = {} # create a dict to hold those which have duplicated keys
for y in ys.values():
zs[y] = y in zs
return ', '.join(
'"{}.{}"'.format(x.namespace, ys[x])
if zs[ys[x]] and getattr(x, 'namespace', None)
else '"{}"'.format(ys[x]) for x in xs
)
request_names = fmt(methods)
type_names = fmt(types)
constructor_names = fmt(constructors)
def fmt(xs, formatter):
return ', '.join('"{}"'.format(formatter(x)) for x in xs)
request_names = fmt(methods, get_class_name)
type_names = fmt(types, get_class_name)
constructor_names = fmt(constructors, get_class_name)
request_urls = fmt(methods, get_create_path_for)
type_urls = fmt(types, get_create_path_for)
type_urls = fmt(types, get_path_for_type)
constructor_urls = fmt(constructors, get_create_path_for)
replace_dict = {
@ -528,13 +554,15 @@ def generate_documentation(scheme_file):
'constructor_urls': constructor_urls
}
with open('../res/core.html') as infile:
with open(original_paths['index_all'], 'w') as outfile:
text = infile.read()
for key, value in replace_dict.items():
text = text.replace('{' + key + '}', str(value))
shutil.copy('../res/404.html', original_paths['404'])
outfile.write(text)
with open('../res/core.html') as infile,\
open(original_paths['index_all'], 'w') as outfile:
text = infile.read()
for key, value in replace_dict.items():
text = text.replace('{' + key + '}', str(value))
outfile.write(text)
# Everything done
print('Documentation generated.')
@ -551,5 +579,8 @@ def copy_resources():
if __name__ == '__main__':
os.makedirs('generated', exist_ok=True)
os.chdir('generated')
generate_documentation('../../telethon_generator/scheme.tl')
copy_resources()
try:
generate_documentation('../../telethon_generator/scheme.tl')
copy_resources()
finally:
os.chdir(os.pardir)

44
docs/res/404.html Normal file
View File

@ -0,0 +1,44 @@
<!DOCTYPE html>
<html><head>
<title>Oopsie! | Telethon</title>
<meta charset="utf-8">
<meta http-equiv="Content-type" content="text/html; charset=utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style type="text/css">
body {
background-color: #f0f4f8;
font-family: "Open Sans", "Helvetica Neue", Helvetica, Arial, sans-serif;
}
div {
width: 560px;
margin: 5em auto;
padding: 50px;
background-color: #fff;
border-radius: 1em;
}
a:link, a:visited {
color: #38488f;
text-decoration: none;
}
@media (max-width: 700px) {
body {
background-color: #fff;
}
div {
width: auto;
margin: 0 auto;
border-radius: 0;
padding: 1em;
}
}
</style>
</head>
<body>
<div>
<h1>You seem a bit lost…</h1>
<p>You seem to be lost! Don't worry, that's just Telegram's API being
itself. Shall we go back to the <a href="index.html">Main Page</a>?</p>
</div>
</body>
</html>

View File

@ -12,34 +12,53 @@ Extra supported commands are:
"""
# To use a consistent encoding
from subprocess import run
from shutil import rmtree
from codecs import open
from sys import argv
from os import path
import os
# Always prefer setuptools over distutils
from setuptools import find_packages, setup
try:
from telethon import TelegramClient
except ImportError:
except Exception as e:
print('Failed to import TelegramClient due to', e)
TelegramClient = None
if __name__ == '__main__':
if len(argv) >= 2 and argv[1] == 'gen_tl':
from telethon_generator.tl_generator import TLGenerator
generator = TLGenerator('telethon/tl')
if generator.tlobjects_exist():
print('Detected previous TLObjects. Cleaning...')
generator.clean_tlobjects()
class TempWorkDir:
"""Switches the working directory to be the one on which this file lives,
while within the 'with' block.
"""
def __init__(self):
self.original = None
print('Generating TLObjects...')
generator.generate_tlobjects(
'telethon_generator/scheme.tl', import_depth=2
)
print('Done.')
def __enter__(self):
self.original = os.path.abspath(os.path.curdir)
os.chdir(os.path.abspath(os.path.dirname(__file__)))
return self
def __exit__(self, *args):
os.chdir(self.original)
def gen_tl():
    """Generates the TLObjects into 'telethon/tl' from the scheme file,
    cleaning any previously generated TLObjects first.
    """
    # Imported here rather than at module level
    from telethon_generator.tl_generator import TLGenerator

    tl_generator = TLGenerator('telethon/tl')
    already_generated = tl_generator.tlobjects_exist()
    if already_generated:
        print('Detected previous TLObjects. Cleaning...')
        tl_generator.clean_tlobjects()

    print('Generating TLObjects...')
    scheme_file = 'telethon_generator/scheme.tl'
    tl_generator.generate_tlobjects(scheme_file, import_depth=2)
    print('Done.')
def main():
if len(argv) >= 2 and argv[1] == 'gen_tl':
gen_tl()
elif len(argv) >= 2 and argv[1] == 'clean_tl':
from telethon_generator.tl_generator import TLGenerator
@ -48,6 +67,11 @@ if __name__ == '__main__':
print('Done.')
elif len(argv) >= 2 and argv[1] == 'pypi':
# Need python3.5 or higher, but Telethon is supposed to support 3.x
# Place it here since no one should be running ./setup.py pypi anyway
from subprocess import run
from shutil import rmtree
for x in ('build', 'dist', 'Telethon.egg-info'):
rmtree(x, ignore_errors=True)
run('python3 setup.py sdist', shell=True)
@ -58,20 +82,21 @@ if __name__ == '__main__':
else:
if not TelegramClient:
print('Run `python3', argv[0], 'gen_tl` first.')
quit()
here = path.abspath(path.dirname(__file__))
gen_tl()
from telethon import TelegramClient as TgClient
version = TgClient.__version__
else:
version = TelegramClient.__version__
# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
with open('README.rst', encoding='utf-8') as f:
long_description = f.read()
setup(
name='Telethon',
# Versions should comply with PEP440.
version=TelegramClient.__version__,
version=version,
description="Full-featured Telegram client library for Python 3",
long_description=long_description,
@ -108,3 +133,8 @@ if __name__ == '__main__':
]),
install_requires=['pyaes', 'rsa']
)
if __name__ == '__main__':
with TempWorkDir(): # Could just use a try/finally but this is + reusable
main()

View File

@ -1,7 +1,8 @@
import struct
from hashlib import sha1
from .. import helpers as utils
from ..extensions import BinaryReader, BinaryWriter
from ..extensions import BinaryReader
class AuthKey:
@ -17,10 +18,6 @@ class AuthKey:
"""Calculates the new nonce hash based on
the current class fields' values
"""
with BinaryWriter() as writer:
writer.write(new_nonce)
writer.write_byte(number)
writer.write_long(self.aux_hash, signed=False)
new_nonce_hash = utils.calc_msg_key(writer.get_bytes())
return new_nonce_hash
new_nonce = new_nonce.to_bytes(32, 'little', signed=True)
data = new_nonce + struct.pack('<BQ', number, self.aux_hash)
return utils.calc_msg_key(data)

View File

@ -10,56 +10,35 @@ from ..errors import CdnFileTamperedError
class CdnDecrypter:
"""Used when downloading a file results in a 'FileCdnRedirect' to
both prepare the redirect, decrypt the file as it downloads, and
ensure the file hasn't been tampered.
ensure the file hasn't been tampered. https://core.telegram.org/cdn
"""
def __init__(self, cdn_client, file_token, cdn_aes, cdn_file_hashes):
self.client = cdn_client
self.file_token = file_token
self.cdn_aes = cdn_aes
self.cdn_file_hashes = cdn_file_hashes
self.shaes = [sha256() for _ in range(len(cdn_file_hashes))]
@staticmethod
def prepare_decrypter(client, client_cls, cdn_redirect):
def prepare_decrypter(client, cdn_client, cdn_redirect):
"""Prepares a CDN decrypter, returning (decrypter, file data).
'client' should be the original TelegramBareClient that
tried to download the file.
'client_cls' should be the class of the TelegramBareClient.
'client' should be an existing client not connected to a CDN.
'cdn_client' should be an already-connected TelegramBareClient
with the auth key already created.
"""
# TODO Avoid the need for 'client_cls=TelegramBareClient'
# https://core.telegram.org/cdn
cdn_aes = AESModeCTR(
key=cdn_redirect.encryption_key,
# 12 first bytes of the IV..4 bytes of the offset (0, big endian)
iv=cdn_redirect.encryption_iv[:12] + bytes(4)
)
# Create a new client on said CDN
dc = client._get_dc(cdn_redirect.dc_id, cdn=True)
session = Session(client.session)
session.server_address = dc.ip_address
session.port = dc.port
cdn_client = client_cls( # Avoid importing TelegramBareClient
session, client.api_id, client.api_hash,
timeout=client._timeout
)
# This will make use of the new RSA keys for this specific CDN.
#
# We assume that cdn_redirect.cdn_file_hashes are ordered by offset,
# and that there will be enough of these to retrieve the whole file.
#
# This relies on the fact that TelegramBareClient._dc_options is
# static and it won't be called from this DC (it would fail).
cdn_client.connect()
# CDN client is ready, create the resulting CdnDecrypter
decrypter = CdnDecrypter(
cdn_client, cdn_redirect.file_token,
cdn_aes, cdn_redirect.cdn_file_hashes
)
cdn_file = client(GetCdnFileRequest(
cdn_file = cdn_client(GetCdnFileRequest(
file_token=cdn_redirect.file_token,
offset=cdn_redirect.cdn_file_hashes[0].offset,
limit=cdn_redirect.cdn_file_hashes[0].limit

View File

@ -1,4 +1,5 @@
import os
import struct
from hashlib import sha1
try:
import rsa
@ -7,7 +8,7 @@ except ImportError:
rsa = None
raise ImportError('Missing module "rsa", please install via pip.')
from ..extensions import BinaryWriter
from ..tl import TLObject
# {fingerprint: Crypto.PublicKey.RSA._RSAobj} dictionary
@ -34,11 +35,10 @@ def _compute_fingerprint(key):
"""For a given Crypto.RSA key, computes its 8-bytes-long fingerprint
in the same way that Telegram does.
"""
with BinaryWriter() as writer:
writer.tgwrite_bytes(get_byte_array(key.n))
writer.tgwrite_bytes(get_byte_array(key.e))
# Telegram uses the last 8 bytes as the fingerprint
return sha1(writer.get_bytes()).digest()[-8:]
n = TLObject.serialize_bytes(get_byte_array(key.n))
e = TLObject.serialize_bytes(get_byte_array(key.e))
# Telegram uses the last 8 bytes as the fingerprint
return struct.unpack('<q', sha1(n + e).digest()[-8:])[0]
def add_key(pub):

View File

@ -1,5 +1,6 @@
import urllib.request
import re
from threading import Thread
from .common import (
ReadCancelledError, InvalidParameterError, TypeNotFoundError,
@ -18,21 +19,29 @@ from .rpc_errors_401 import *
from .rpc_errors_420 import *
def report_error(code, message, report_method):
    """Reports an RPC error to the remote error-collection endpoint.

    This is strictly best-effort: any failure while building or sending
    the report is swallowed so that reporting can never break the caller.

    code: sent as the 'code' query parameter.
    message: sent as the 'error' query parameter.
    report_method: integer sent as the 'method' query parameter, after
        being reinterpreted as a signed 32-bit value.
    """
    try:
        # Ensure it's signed. Masking to 32 bits first means a value which
        # is already negative is accepted too, instead of to_bytes() raising
        # OverflowError and the report being silently dropped.
        report_method = int.from_bytes(
            (report_method & 0xffffffff).to_bytes(4, 'big'),
            'big', signed=True
        )
        request_url = (
            'https://rpc.pwrtelegram.xyz?code={}&error={}&method={}'
            .format(code, message, report_method)
        )
        # The context manager closes the response even if read() fails
        with urllib.request.urlopen(request_url, timeout=5) as response:
            response.read()
    except Exception:
        # We really don't want to crash when just reporting an error
        pass
def rpc_message_to_error(code, message, report_method=None):
if report_method is not None:
try:
# Ensure it's signed
report_method = int.from_bytes(
report_method.to_bytes(4, 'big'), 'big', signed=True
)
url = urllib.request.urlopen(
'https://rpc.pwrtelegram.xyz?code={}&error={}&method={}'
.format(code, message, report_method)
)
url.read()
url.close()
except:
"We really don't want to crash when just reporting an error"
Thread(
target=report_error,
args=(code, message, report_method)
).start()
errors = {
303: rpc_errors_303_all,

View File

@ -18,6 +18,15 @@ class BotMethodInvalidError(BadRequestError):
)
class CdnMethodInvalidError(BadRequestError):
    """Error mapped from the 'CDN_METHOD_INVALID' RPC error message."""
    def __init__(self, **kwargs):
        message = (
            'This method cannot be invoked on a CDN server. Refer to '
            'https://core.telegram.org/cdn#schema for available methods.'
        )
        super(Exception, self).__init__(self, message)
class ChannelInvalidError(BadRequestError):
def __init__(self, **kwargs):
super(Exception, self).__init__(
@ -134,6 +143,16 @@ class InputMethodInvalidError(BadRequestError):
)
class InputRequestTooLongError(BadRequestError):
    """Error mapped from the 'INPUT_REQUEST_TOO_LONG' RPC error message."""
    def __init__(self, **kwargs):
        # Each fragment but the last must end in a space: adjacent string
        # literals are concatenated as-is ("(like" + "appending" otherwise).
        super(Exception, self).__init__(
            self,
            'The input request was too long. This may be a bug in the library '
            'as it can occur when serializing more bytes than it should (like '
            'appending the vector constructor code at the end of a message).'
        )
class LastNameInvalidError(BadRequestError):
def __init__(self, **kwargs):
super(Exception, self).__init__(
@ -142,6 +161,24 @@ class LastNameInvalidError(BadRequestError):
)
class LimitInvalidError(BadRequestError):
    """Error mapped from the 'LIMIT_INVALID' RPC error message."""
    def __init__(self, **kwargs):
        message = (
            'An invalid limit was provided. See '
            'https://core.telegram.org/api/files#downloading-files'
        )
        super(Exception, self).__init__(self, message)
class LocationInvalidError(BadRequestError):
    """Error mapped from the 'LOCATION_INVALID' RPC error message."""
    def __init__(self, **kwargs):
        message = (
            'The location given for a file was invalid. See '
            'https://core.telegram.org/api/files#downloading-files'
        )
        super(Exception, self).__init__(self, message)
class Md5ChecksumInvalidError(BadRequestError):
def __init__(self, **kwargs):
super(Exception, self).__init__(
@ -191,6 +228,16 @@ class MsgWaitFailedError(BadRequestError):
)
class OffsetInvalidError(BadRequestError):
    """Error mapped from the 'OFFSET_INVALID' RPC error message."""
    def __init__(self, **kwargs):
        message = (
            'The given offset was invalid, it must be divisible by 1KB. '
            'See https://core.telegram.org/api/files#downloading-files'
        )
        super(Exception, self).__init__(self, message)
class PasswordHashInvalidError(BadRequestError):
def __init__(self, **kwargs):
super(Exception, self).__init__(
@ -350,6 +397,7 @@ class UserIdInvalidError(BadRequestError):
rpc_errors_400_all = {
'API_ID_INVALID': ApiIdInvalidError,
'BOT_METHOD_INVALID': BotMethodInvalidError,
'CDN_METHOD_INVALID': CdnMethodInvalidError,
'CHANNEL_INVALID': ChannelInvalidError,
'CHAT_ADMIN_REQUIRED': ChatAdminRequiredError,
'CHAT_ID_INVALID': ChatIdInvalidError,
@ -362,13 +410,17 @@ rpc_errors_400_all = {
'FILE_PART_INVALID': FilePartInvalidError,
'FIRSTNAME_INVALID': FirstNameInvalidError,
'INPUT_METHOD_INVALID': InputMethodInvalidError,
'INPUT_REQUEST_TOO_LONG': InputRequestTooLongError,
'LASTNAME_INVALID': LastNameInvalidError,
'LIMIT_INVALID': LimitInvalidError,
'LOCATION_INVALID': LocationInvalidError,
'MD5_CHECKSUM_INVALID': Md5ChecksumInvalidError,
'MESSAGE_EMPTY': MessageEmptyError,
'MESSAGE_ID_INVALID': MessageIdInvalidError,
'MESSAGE_TOO_LONG': MessageTooLongError,
'MESSAGE_NOT_MODIFIED': MessageNotModifiedError,
'MSG_WAIT_FAILED': MsgWaitFailedError,
'OFFSET_INVALID': OffsetInvalidError,
'PASSWORD_HASH_INVALID': PasswordHashInvalidError,
'PEER_ID_INVALID': PeerIdInvalidError,
'PHONE_CODE_EMPTY': PhoneCodeEmptyError,

View File

@ -1,9 +1,7 @@
"""
Several extensions Python is missing, such as a proper class to handle a TCP
communication with support for cancelling the operation, and a utility class
to work with arbitrary binary data in a more comfortable way (writing ints,
strings, bytes, etc.)
to read arbitrary binary data in a more comfortable way, with int/strings/etc.
"""
from .binary_writer import BinaryWriter
from .binary_reader import BinaryReader
from .tcp_client import TcpClient

View File

@ -1,152 +0,0 @@
from io import BufferedWriter, BytesIO, DEFAULT_BUFFER_SIZE
from struct import pack
class BinaryWriter:
    """
    Small helper to serialize values into a binary stream.
    An in-memory stream is created when none is supplied
    """

    def __init__(self, stream=None, known_length=None):
        if known_length is None:
            # On some systems, DEFAULT_BUFFER_SIZE defaults to 8192
            # That's over 16 times as big as necessary for most messages
            known_length = max(DEFAULT_BUFFER_SIZE, 1024)

        self.writer = BufferedWriter(
            stream if stream else BytesIO(), buffer_size=known_length
        )
        # Running total of the bytes handed to the writer by this instance
        self.written_count = 0

    # region Writing

    # "All numbers are written as little endian."
    # https://core.telegram.org/mtproto
    def write_byte(self, value):
        """Writes 'value' as a single unsigned byte"""
        self.write(pack('B', value))

    def write_int(self, value, signed=True):
        """Writes a 4 bytes integer, signed unless told otherwise"""
        self.write_large_int(value, 32, signed=signed)

    def write_long(self, value, signed=True):
        """Writes an 8 bytes integer, signed unless told otherwise"""
        self.write_large_int(value, 64, signed=signed)

    def write_float(self, value):
        """Writes a 4 bytes floating point value"""
        self.write(pack('<f', value))

    def write_double(self, value):
        """Writes an 8 bytes floating point value"""
        self.write(pack('<d', value))

    def write_large_int(self, value, bits, signed=True):
        """Writes an integer which is 'bits' bits long (bits // 8 bytes)"""
        self.write(int.to_bytes(
            value, length=bits // 8, byteorder='little', signed=signed
        ))

    def write(self, data):
        """Writes the given bytes as they are, keeping count of them"""
        self.writer.write(data)
        self.written_count += len(data)

    # endregion

    # region Telegram custom writing

    def tgwrite_bytes(self, data):
        """Writes bytes as a length prefix, the data itself and padding"""
        size = len(data)
        if size < 254:
            # Short form: the length fits in a single byte
            prefix = [bytes([size])]
            unaligned = (size + 1) % 4
        else:
            # Long form: a 254 marker, then the length in 3 bytes
            prefix = [
                bytes([254]),
                bytes([size % 256]),
                bytes([(size >> 8) % 256]),
                bytes([(size >> 16) % 256])
            ]
            unaligned = size % 4

        for piece in prefix:
            self.write(piece)

        self.write(data)
        # Zero-fill so that everything written is a multiple of 4 bytes
        self.write(bytes(4 - unaligned if unaligned else 0))

    def tgwrite_string(self, string):
        """Writes a string, UTF-8 encoded, through tgwrite_bytes"""
        encoded = string.encode('utf-8')
        self.tgwrite_bytes(encoded)

    def tgwrite_bool(self, boolean):
        """Writes a boolean as its boolTrue/boolFalse constructor ID"""
        if boolean:
            constructor_id = 0x997275b5  # boolTrue
        else:
            constructor_id = 0xbc799737  # boolFalse

        self.write_int(constructor_id, signed=False)

    def tgwrite_date(self, datetime):
        """Writes a Python datetime object as Unix time
           (used by Telegram), or 0 if it is None
        """
        if datetime is None:
            self.write_int(0)
        else:
            self.write_int(int(datetime.timestamp()))

    def tgwrite_object(self, tlobject):
        """Writes a Telegram object by calling its on_send"""
        tlobject.on_send(self)

    def tgwrite_vector(self, vector):
        """Writes a vector: constructor ID, item count and every item"""
        self.write_int(0x1cb5c415, signed=False)  # Vector's constructor ID
        self.write_int(len(vector))
        for tlobject in vector:
            self.tgwrite_object(tlobject)

    # endregion

    def flush(self):
        """Flushes the underlying buffered writer"""
        self.writer.flush()

    def close(self):
        """Closes the underlying buffered writer"""
        self.writer.close()

    def get_bytes(self, flush=True):
        """Returns what the underlying raw stream holds,
           flushing the buffer first unless 'flush' is false
        """
        if flush:
            self.writer.flush()

        raw_stream = self.writer.raw
        return raw_stream.getvalue()

    def get_written_bytes_count(self):
        """Returns how many bytes were written through this writer.
           This may NOT be equal to the stream length if one
           was provided when initializing the writer
        """
        return self.written_count

    # with block
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

View File

@ -8,41 +8,55 @@ from threading import Lock
class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
self._proxy = proxy
self.proxy = proxy
self._socket = None
self._closing_lock = Lock()
if isinstance(timeout, timedelta):
self._timeout = timeout.seconds
self.timeout = timeout.seconds
elif isinstance(timeout, int) or isinstance(timeout, float):
self._timeout = float(timeout)
self.timeout = float(timeout)
else:
raise ValueError('Invalid timeout type', type(timeout))
def _recreate_socket(self, mode):
if self._proxy is None:
if self.proxy is None:
self._socket = socket.socket(mode, socket.SOCK_STREAM)
else:
import socks
self._socket = socks.socksocket(mode, socket.SOCK_STREAM)
if type(self._proxy) is dict:
self._socket.set_proxy(**self._proxy)
if type(self.proxy) is dict:
self._socket.set_proxy(**self.proxy)
else: # tuple, list, etc.
self._socket.set_proxy(*self._proxy)
self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout)
def connect(self, ip, port):
"""Connects to the specified IP and port number.
'timeout' must be given in seconds
"""
if not self.connected:
if ':' in ip: # IPv6
mode, address = socket.AF_INET6, (ip, port, 0, 0)
else:
mode, address = socket.AF_INET, (ip, port)
if ':' in ip: # IPv6
mode, address = socket.AF_INET6, (ip, port, 0, 0)
else:
mode, address = socket.AF_INET, (ip, port)
self._recreate_socket(mode)
self._socket.settimeout(self._timeout)
self._socket.connect(address)
while True:
try:
while not self._socket:
self._recreate_socket(mode)
self._socket.connect(address)
break # Successful connection, stop retrying to connect
except OSError as e:
# There are some errors that we know how to handle, and
# the loop will allow us to retry
if e.errno == errno.EBADF:
# Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration
self._socket = None
else:
raise
def _get_connected(self):
return self._socket is not None
@ -67,6 +81,8 @@ class TcpClient:
def write(self, data):
"""Writes (sends) the specified bytes to the connected peer"""
if self._socket is None:
raise ConnectionResetError()
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
@ -74,13 +90,13 @@ class TcpClient:
self._socket.sendall(data)
except socket.timeout as e:
raise TimeoutError() from e
except BrokenPipeError:
self._raise_connection_reset()
except OSError as e:
if e.errno == errno.EBADF:
self._raise_connection_reset()
else:
raise
except BrokenPipeError:
self._raise_connection_reset()
def read(self, size):
"""Reads (receives) a whole block of 'size bytes
@ -91,6 +107,9 @@ class TcpClient:
and it's waiting for more, the timeout will NOT cancel the
operation. Set to None for no timeout
"""
if self._socket is None:
raise ConnectionResetError()
# TODO Remove the timeout from this method, always use previous one
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size
@ -100,7 +119,7 @@ class TcpClient:
except socket.timeout as e:
raise TimeoutError() from e
except OSError as e:
if e.errno == errno.EBADF:
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK:
self._raise_connection_reset()
else:
raise

View File

@ -47,9 +47,11 @@ def calc_msg_key(data):
def generate_key_data_from_nonce(server_nonce, new_nonce):
"""Generates the key data corresponding to the given nonce"""
hash1 = sha1(bytes(new_nonce + server_nonce)).digest()
hash2 = sha1(bytes(server_nonce + new_nonce)).digest()
hash3 = sha1(bytes(new_nonce + new_nonce)).digest()
server_nonce = server_nonce.to_bytes(16, 'little', signed=True)
new_nonce = new_nonce.to_bytes(32, 'little', signed=True)
hash1 = sha1(new_nonce + server_nonce).digest()
hash2 = sha1(server_nonce + new_nonce).digest()
hash3 = sha1(new_nonce + new_nonce).digest()
key = hash1 + hash2[:12]
iv = hash2[12:20] + hash3 + new_nonce[:4]

View File

@ -2,12 +2,19 @@ import os
import time
from hashlib import sha1
from ..tl.types import (
ResPQ, PQInnerData, ServerDHParamsFail, ServerDHParamsOk,
ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail
)
from .. import helpers as utils
from ..crypto import AES, AuthKey, Factorization
from ..crypto import rsa
from ..errors import SecurityError, TypeNotFoundError
from ..extensions import BinaryReader, BinaryWriter
from ..errors import SecurityError
from ..extensions import BinaryReader
from ..network import MtProtoPlainSender
from ..tl.functions import (
ReqPqRequest, ReqDHParamsRequest, SetClientDHParamsRequest
)
def do_authentication(connection, retries=5):
@ -18,7 +25,7 @@ def do_authentication(connection, retries=5):
while retries:
try:
return _do_authentication(connection)
except (SecurityError, TypeNotFoundError, NotImplementedError) as e:
except (SecurityError, AssertionError, NotImplementedError) as e:
last_error = e
retries -= 1
raise last_error
@ -30,202 +37,158 @@ def _do_authentication(connection):
time offset.
"""
sender = MtProtoPlainSender(connection)
sender.connect()
# Step 1 sending: PQ Request
nonce = os.urandom(16)
with BinaryWriter(known_length=20) as writer:
writer.write_int(0x60469778, signed=False) # Constructor number
writer.write(nonce)
sender.send(writer.get_bytes())
# Step 1 response: PQ Request
pq, pq_bytes, server_nonce, fingerprints = None, None, None, []
# Step 1 sending: PQ Request, endianness doesn't matter since it's random
req_pq_request = ReqPqRequest(
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
)
sender.send(req_pq_request.to_bytes())
with BinaryReader(sender.receive()) as reader:
response_code = reader.read_int(signed=False)
if response_code != 0x05162463:
raise TypeNotFoundError(response_code)
req_pq_request.on_response(reader)
nonce_from_server = reader.read(16)
if nonce_from_server != nonce:
raise SecurityError('Invalid nonce from server')
res_pq = req_pq_request.result
if not isinstance(res_pq, ResPQ):
raise AssertionError(res_pq)
server_nonce = reader.read(16)
if res_pq.nonce != req_pq_request.nonce:
raise SecurityError('Invalid nonce from server')
pq_bytes = reader.tgread_bytes()
pq = get_int(pq_bytes)
vector_id = reader.read_int()
if vector_id != 0x1cb5c415:
raise TypeNotFoundError(response_code)
fingerprints = []
fingerprint_count = reader.read_int()
for _ in range(fingerprint_count):
fingerprints.append(reader.read(8))
pq = get_int(res_pq.pq)
# Step 2 sending: DH Exchange
new_nonce = os.urandom(32)
p, q = Factorization.factorize(pq)
with BinaryWriter() as pq_inner_data_writer:
pq_inner_data_writer.write_int(
0x83c95aec, signed=False) # PQ Inner Data
pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(pq))
pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(min(p, q)))
pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(max(p, q)))
pq_inner_data_writer.write(nonce)
pq_inner_data_writer.write(server_nonce)
pq_inner_data_writer.write(new_nonce)
p, q = rsa.get_byte_array(min(p, q)), rsa.get_byte_array(max(p, q))
new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True)
# sha_digest + data + random_bytes
cipher_text, target_fingerprint = None, None
for fingerprint in fingerprints:
cipher_text = rsa.encrypt(
fingerprint,
pq_inner_data_writer.get_bytes()
pq_inner_data = PQInnerData(
pq=rsa.get_byte_array(pq), p=p, q=q,
nonce=res_pq.nonce,
server_nonce=res_pq.server_nonce,
new_nonce=new_nonce
).to_bytes()
# sha_digest + data + random_bytes
cipher_text, target_fingerprint = None, None
for fingerprint in res_pq.server_public_key_fingerprints:
cipher_text = rsa.encrypt(fingerprint, pq_inner_data)
if cipher_text is not None:
target_fingerprint = fingerprint
break
if cipher_text is None:
raise SecurityError(
'Could not find a valid key for fingerprints: {}'
.format(', '.join(
[str(f) for f in res_pq.server_public_key_fingerprints])
)
)
if cipher_text is not None:
target_fingerprint = fingerprint
break
if cipher_text is None:
raise SecurityError(
'Could not find a valid key for fingerprints: {}'
.format(', '.join([repr(f) for f in fingerprints]))
)
with BinaryWriter() as req_dh_params_writer:
req_dh_params_writer.write_int(
0xd712e4be, signed=False) # Req DH Params
req_dh_params_writer.write(nonce)
req_dh_params_writer.write(server_nonce)
req_dh_params_writer.tgwrite_bytes(rsa.get_byte_array(min(p, q)))
req_dh_params_writer.tgwrite_bytes(rsa.get_byte_array(max(p, q)))
req_dh_params_writer.write(target_fingerprint)
req_dh_params_writer.tgwrite_bytes(cipher_text)
req_dh_params_bytes = req_dh_params_writer.get_bytes()
sender.send(req_dh_params_bytes)
req_dh_params = ReqDHParamsRequest(
nonce=res_pq.nonce,
server_nonce=res_pq.server_nonce,
p=p, q=q,
public_key_fingerprint=target_fingerprint,
encrypted_data=cipher_text
)
sender.send(req_dh_params.to_bytes())
# Step 2 response: DH Exchange
encrypted_answer = None
with BinaryReader(sender.receive()) as reader:
response_code = reader.read_int(signed=False)
req_dh_params.on_response(reader)
if response_code == 0x79cb045d:
raise SecurityError('Server DH params fail: TODO')
server_dh_params = req_dh_params.result
if isinstance(server_dh_params, ServerDHParamsFail):
raise SecurityError('Server DH params fail: TODO')
if response_code != 0xd0e8075c:
raise TypeNotFoundError(response_code)
if not isinstance(server_dh_params, ServerDHParamsOk):
raise AssertionError(server_dh_params)
nonce_from_server = reader.read(16)
if nonce_from_server != nonce:
raise SecurityError('Invalid nonce from server')
if server_dh_params.nonce != res_pq.nonce:
raise SecurityError('Invalid nonce from server')
server_nonce_from_server = reader.read(16)
if server_nonce_from_server != server_nonce:
raise SecurityError('Invalid server nonce from server')
encrypted_answer = reader.tgread_bytes()
if server_dh_params.server_nonce != res_pq.server_nonce:
raise SecurityError('Invalid server nonce from server')
# Step 3 sending: Complete DH Exchange
key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce)
plain_text_answer = AES.decrypt_ige(encrypted_answer, key, iv)
key, iv = utils.generate_key_data_from_nonce(
res_pq.server_nonce, new_nonce
)
plain_text_answer = AES.decrypt_ige(
server_dh_params.encrypted_answer, key, iv
)
g, dh_prime, ga, time_offset = None, None, None, None
with BinaryReader(plain_text_answer) as dh_inner_data_reader:
dh_inner_data_reader.read(20) # hash sum
code = dh_inner_data_reader.read_int(signed=False)
if code != 0xb5890dba:
raise TypeNotFoundError(code)
with BinaryReader(plain_text_answer) as reader:
reader.read(20) # hash sum
server_dh_inner = reader.tgread_object()
if not isinstance(server_dh_inner, ServerDHInnerData):
raise AssertionError(server_dh_inner)
nonce_from_server1 = dh_inner_data_reader.read(16)
if nonce_from_server1 != nonce:
raise SecurityError('Invalid nonce in encrypted answer')
if server_dh_inner.nonce != res_pq.nonce:
print(server_dh_inner.nonce, res_pq.nonce)
raise SecurityError('Invalid nonce in encrypted answer')
server_nonce_from_server1 = dh_inner_data_reader.read(16)
if server_nonce_from_server1 != server_nonce:
raise SecurityError('Invalid server nonce in encrypted answer')
if server_dh_inner.server_nonce != res_pq.server_nonce:
raise SecurityError('Invalid server nonce in encrypted answer')
g = dh_inner_data_reader.read_int()
dh_prime = get_int(dh_inner_data_reader.tgread_bytes(), signed=False)
ga = get_int(dh_inner_data_reader.tgread_bytes(), signed=False)
server_time = dh_inner_data_reader.read_int()
time_offset = server_time - int(time.time())
dh_prime = get_int(server_dh_inner.dh_prime, signed=False)
g_a = get_int(server_dh_inner.g_a, signed=False)
time_offset = server_dh_inner.server_time - int(time.time())
b = get_int(os.urandom(256), signed=False)
gb = pow(g, b, dh_prime)
gab = pow(ga, b, dh_prime)
gb = pow(server_dh_inner.g, b, dh_prime)
gab = pow(g_a, b, dh_prime)
# Prepare client DH Inner Data
with BinaryWriter() as client_dh_inner_data_writer:
client_dh_inner_data_writer.write_int(
0x6643b654, signed=False) # Client DH Inner Data
client_dh_inner_data_writer.write(nonce)
client_dh_inner_data_writer.write(server_nonce)
client_dh_inner_data_writer.write_long(0) # TODO retry_id
client_dh_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(gb))
client_dh_inner = ClientDHInnerData(
nonce=res_pq.nonce,
server_nonce=res_pq.server_nonce,
retry_id=0, # TODO Actual retry ID
g_b=rsa.get_byte_array(gb)
).to_bytes()
with BinaryWriter() as client_dh_inner_data_with_hash_writer:
client_dh_inner_data_with_hash_writer.write(
sha1(client_dh_inner_data_writer.get_bytes()).digest())
client_dh_inner_data_with_hash_writer.write(
client_dh_inner_data_writer.get_bytes())
client_dh_inner_data_bytes = \
client_dh_inner_data_with_hash_writer.get_bytes()
client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner
# Encryption
client_dh_inner_data_encrypted_bytes = AES.encrypt_ige(
client_dh_inner_data_bytes, key, iv)
client_dh_encrypted = AES.encrypt_ige(client_dh_inner_hashed, key, iv)
# Prepare Set client DH params
with BinaryWriter() as set_client_dh_params_writer:
set_client_dh_params_writer.write_int(0xf5045f1f, signed=False)
set_client_dh_params_writer.write(nonce)
set_client_dh_params_writer.write(server_nonce)
set_client_dh_params_writer.tgwrite_bytes(
client_dh_inner_data_encrypted_bytes)
set_client_dh_params_bytes = set_client_dh_params_writer.get_bytes()
sender.send(set_client_dh_params_bytes)
set_client_dh = SetClientDHParamsRequest(
nonce=res_pq.nonce,
server_nonce=res_pq.server_nonce,
encrypted_data=client_dh_encrypted,
)
sender.send(set_client_dh.to_bytes())
# Step 3 response: Complete DH Exchange
with BinaryReader(sender.receive()) as reader:
# Everything read from the server, disconnect now
sender.disconnect()
set_client_dh.on_response(reader)
code = reader.read_int(signed=False)
if code == 0x3bcbf734: # DH Gen OK
nonce_from_server = reader.read(16)
if nonce_from_server != nonce:
raise SecurityError('Invalid nonce from server')
dh_gen = set_client_dh.result
if isinstance(dh_gen, DhGenOk):
if dh_gen.nonce != res_pq.nonce:
raise SecurityError('Invalid nonce from server')
server_nonce_from_server = reader.read(16)
if server_nonce_from_server != server_nonce:
raise SecurityError('Invalid server nonce from server')
if dh_gen.server_nonce != res_pq.server_nonce:
raise SecurityError('Invalid server nonce from server')
new_nonce_hash1 = reader.read(16)
auth_key = AuthKey(rsa.get_byte_array(gab))
auth_key = AuthKey(rsa.get_byte_array(gab))
new_nonce_hash = int.from_bytes(
auth_key.calc_new_nonce_hash(new_nonce, 1), 'little', signed=True
)
new_nonce_hash_calculated = auth_key.calc_new_nonce_hash(new_nonce,
1)
if new_nonce_hash1 != new_nonce_hash_calculated:
raise SecurityError('Invalid new nonce hash')
if dh_gen.new_nonce_hash1 != new_nonce_hash:
raise SecurityError('Invalid new nonce hash')
return auth_key, time_offset
return auth_key, time_offset
elif code == 0x46dc1fb9: # DH Gen Retry
raise NotImplementedError('dh_gen_retry')
elif isinstance(dh_gen, DhGenRetry):
raise NotImplementedError('DhGenRetry')
elif code == 0xa69dae02: # DH Gen Fail
raise NotImplementedError('dh_gen_fail')
elif isinstance(dh_gen, DhGenFail):
raise NotImplementedError('DhGenFail')
else:
raise NotImplementedError('DH Gen unknown: {}'.format(hex(code)))
else:
raise NotImplementedError('DH Gen unknown: {}'.format(dh_gen))
def get_int(byte_array, signed=True):

View File

@ -1,10 +1,13 @@
import os
import struct
from datetime import timedelta
from zlib import crc32
from enum import Enum
import errno
from ..crypto import AESModeCTR
from ..extensions import BinaryWriter, TcpClient
from ..extensions import TcpClient
from ..errors import InvalidChecksumError
@ -75,9 +78,15 @@ class Connection:
setattr(self, 'read', self._read_plain)
def connect(self):
self._send_counter = 0
self.conn.connect(self.ip, self.port)
try:
self.conn.connect(self.ip, self.port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
self._send_counter = 0
if self._mode == ConnectionMode.TCP_ABRIDGED:
self.conn.write(b'\xef')
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
@ -85,6 +94,9 @@ class Connection:
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation()
def get_timeout(self):
    """Returns the timeout currently set on the underlying
       connection object (self.conn).
    """
    underlying = self.conn
    return underlying.timeout
def _setup_obfuscation(self):
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
@ -118,6 +130,13 @@ class Connection:
def close(self):
self.conn.close()
def clone(self):
    """Builds another Connection pointing at the same ip and port,
       reusing this one's mode and the proxy/timeout of self.conn.
    """
    address = (self.ip, self.port)
    options = dict(
        mode=self._mode,
        proxy=self.conn.proxy,
        timeout=self.conn.timeout
    )
    return Connection(*address, **options)
# region Receive message implementations
def recv(self):
@ -164,30 +183,22 @@ class Connection:
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
with BinaryWriter(known_length=length) as writer:
writer.write_int(length)
writer.write_int(self._send_counter)
writer.write(message)
writer.write_int(crc32(writer.get_bytes()), signed=False)
self._send_counter += 1
self.write(writer.get_bytes())
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self.write(data + crc)
def _send_intermediate(self, message):
with BinaryWriter(known_length=len(message) + 4) as writer:
writer.write_int(len(message))
writer.write(message)
self.write(writer.get_bytes())
self.write(struct.pack('<i', len(message)) + message)
def _send_abridged(self, message):
with BinaryWriter(known_length=len(message) + 4) as writer:
length = len(message) >> 2
if length < 127:
writer.write_byte(length)
else:
writer.write_byte(127)
writer.write(int.to_bytes(length, 3, 'little'))
writer.write(message)
self.write(writer.get_bytes())
length = len(message) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
self.write(length + message)
# endregion

View File

@ -1,7 +1,8 @@
import struct
import time
from ..errors import BrokenAuthKeyError
from ..extensions import BinaryReader, BinaryWriter
from ..extensions import BinaryReader
class MtProtoPlainSender:
@ -25,14 +26,9 @@ class MtProtoPlainSender:
"""Sends a plain packet (auth_key_id = 0) containing the
given message body (data)
"""
with BinaryWriter(known_length=len(data) + 20) as writer:
writer.write_long(0)
writer.write_long(self._get_new_msg_id())
writer.write_int(len(data))
writer.write(data)
packet = writer.get_bytes()
self._connection.send(packet)
self._connection.send(
struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
)
def receive(self):
"""Receives a plain packet, returning the body of the response"""

View File

@ -1,41 +1,45 @@
import gzip
import logging
from threading import RLock
import struct
from .. import helpers as utils
from ..crypto import AES
from ..errors import BadMessageError, InvalidChecksumError, rpc_message_to_error
from ..extensions import BinaryReader, BinaryWriter
from ..errors import (
BadMessageError, InvalidChecksumError, BrokenAuthKeyError,
rpc_message_to_error
)
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck
from ..tl.functions.auth import LogOutRequest
logging.getLogger(__name__).addHandler(logging.NullHandler())
class MtProtoSender:
"""MTProto Mobile Protocol sender
(https://core.telegram.org/mtproto/description)
(https://core.telegram.org/mtproto/description).
Note that this class is not thread-safe, and calling send/receive
from two or more threads at the same time is undefined behaviour.
Rationale: a new connection should be spawned to send/receive requests
in parallel, so thread-safety (hence locking) isn't needed.
"""
def __init__(self, connection, session):
def __init__(self, session, connection):
"""Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'.
"""
self.connection = connection
self.session = session
self.connection = connection
self._logger = logging.getLogger(__name__)
self._need_confirmation = [] # Message IDs that need confirmation
self._pending_receive = [] # Requests sent waiting to be received
# Message IDs that need confirmation
self._need_confirmation = []
# Sending and receiving are independent, but two threads cannot
# send or receive at the same time no matter what.
self._send_lock = RLock()
self._recv_lock = RLock()
# Used when logging out, the only request that seems to use 'ack'
# TODO There might be a better way to handle msgs_ack requests
self.logging_out = False
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
def connect(self):
"""Connects to the server"""
@ -47,33 +51,39 @@ class MtProtoSender:
def disconnect(self):
"""Disconnects from the server"""
self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending()
def clone(self):
    """Returns a new MtProtoSender that uses the same session as this
       one, but a clone of its connection (self.connection.clone()).
    """
    cloned_connection = self.connection.clone()
    return MtProtoSender(self.session, cloned_connection)
# region Send and receive
def send(self, request):
def send(self, *requests):
"""Sends the specified MTProtoRequest, previously sending any message
which needed confirmation."""
# If any message needs confirmation send an AckRequest first
self._send_acknowledges()
# Finally send our packed request
with BinaryWriter() as writer:
request.on_send(writer)
self._send_packet(writer.get_bytes(), request)
self._pending_receive.append(request)
# Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages})
# And update the saved session
self.session.save()
if len(messages) == 1:
message = messages[0]
else:
message = TLMessage(self.session, MessageContainer(messages))
self._send_message(message)
def _send_acknowledges(self):
"""Sends a messages acknowledge for all those who _need_confirmation"""
if self._need_confirmation:
msgs_ack = MsgsAck(self._need_confirmation)
with BinaryWriter() as writer:
msgs_ack.on_send(writer)
self._send_packet(writer.get_bytes(), msgs_ack)
self._send_message(
TLMessage(self.session, MsgsAck(self._need_confirmation))
)
del self._need_confirmation[:]
def receive(self, update_state):
@ -86,21 +96,18 @@ class MtProtoSender:
Any unhandled object (likely updates) will be passed to
update_state.process(TLObject).
"""
with self._recv_lock:
try:
body = self.connection.recv()
except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear
# everything to be on the safe side, or:
#
# "This packet should be skipped"; since this may have
# been a result for a request, invalidate every request
# and just re-invoke them to avoid problems
for r in self._pending_receive:
r.confirm_received.set()
self._pending_receive.clear()
return
try:
body = self.connection.recv()
except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear
# everything to be on the safe side, or:
#
# "This packet should be skipped"; since this may have
# been a result for a request, invalidate every request
# and just re-invoke them to avoid problems
self._clear_all_pending()
return
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
@ -110,36 +117,20 @@ class MtProtoSender:
# region Low level processing
def _send_packet(self, packet, request):
"""Sends the given packet bytes with the additional
information of the original request.
"""
request.request_msg_id = self.session.get_new_msg_id()
def _send_message(self, message):
"""Sends the given Message(TLObject) encrypted through the network"""
# First calculate plain_text to encrypt it
with BinaryWriter() as plain_writer:
plain_writer.write_long(self.session.salt, signed=False)
plain_writer.write_long(self.session.id, signed=False)
plain_writer.write_long(request.request_msg_id)
plain_writer.write_int(
self.session.generate_sequence(request.content_related))
plain_text = \
struct.pack('<QQ', self.session.salt, self.session.id) \
+ message.to_bytes()
plain_writer.write_int(len(packet))
plain_writer.write(packet)
msg_key = utils.calc_msg_key(plain_text)
key_id = struct.pack('<Q', self.session.auth_key.key_id)
key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True)
cipher_text = AES.encrypt_ige(plain_text, key, iv)
msg_key = utils.calc_msg_key(plain_writer.get_bytes())
key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True)
cipher_text = AES.encrypt_ige(plain_writer.get_bytes(), key, iv)
# And then finally send the encrypted packet
with BinaryWriter() as cipher_writer:
cipher_writer.write_long(
self.session.auth_key.key_id, signed=False)
cipher_writer.write(msg_key)
cipher_writer.write(cipher_text)
with self._send_lock:
self.connection.send(cipher_writer.get_bytes())
result = key_id + msg_key + cipher_text
self.connection.send(result)
def _decode_msg(self, body):
"""Decodes a received encrypted message body (bytes)"""
@ -149,7 +140,10 @@ class MtProtoSender:
with BinaryReader(body) as reader:
if len(body) < 8:
raise BufferError("Can't decode packet ({})".format(body))
if body == b'l\xfe\xff\xff':
raise BrokenAuthKeyError()
else:
raise BufferError("Can't decode packet ({})".format(body))
# TODO Check for both auth key ID and msg_key correctness
reader.read_long() # remote_auth_key_id
@ -204,14 +198,15 @@ class MtProtoSender:
# msgs_ack, it may handle the request we wanted
if code == 0x62d6b459:
ack = reader.tgread_object()
for r in self._pending_receive:
if r.request_msg_id in ack.msg_ids:
self._logger.debug('Ack found for the a request')
if self.logging_out:
self._logger.debug('Message ack confirmed a request')
self._pending_receive.remove(r)
r.confirm_received.set()
# Ignore every ack request *unless* we are logging out, which is the
# only case where it seems to make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in ack.msg_ids:
r = self._pop_request_of_type(msg_id, LogOutRequest)
if r:
r.result = True # Telegram won't send this value
r.confirm_received()
self._logger.debug('Message ack confirmed', r)
return True
@ -237,13 +232,26 @@ class MtProtoSender:
# region Message handling
def _pop_request(self, request_msg_id):
"""Pops a pending request from self._pending_receive, or
returns None if it's not found
def _pop_request(self, msg_id):
"""Pops a pending REQUEST from self._pending_receive, or
returns None if it's not found.
"""
for i in range(len(self._pending_receive)):
if self._pending_receive[i].request_msg_id == request_msg_id:
return self._pending_receive.pop(i)
message = self._pending_receive.pop(msg_id, None)
if message:
return message.request
def _pop_request_of_type(self, msg_id, t):
    """Pops a pending REQUEST from self._pending_receive if it matches
       the given type, or returns None if it's not found/doesn't match.

       'msg_id' is the key used in self._pending_receive and 't' is the
       type (or tuple of types) the stored message's request must be an
       instance of. When the request does not match, the entry is left
       untouched in self._pending_receive.
    """
    message = self._pending_receive.get(msg_id)
    # An unknown msg_id yields no message at all; check for that before
    # touching '.request' so "not found" returns None as documented
    # instead of raising AttributeError.
    if message is not None and isinstance(message.request, t):
        return self._pending_receive.pop(msg_id).request
    return None
def _clear_all_pending(self):
    """Signals every pending request as received and forgets them all.

       self._pending_receive maps msg_id to the message wrapping each
       request, so the 'confirm_received' event to set is the one on
       message.request (the object ._pop_request() hands back), not one
       on the wrapping message itself.
    """
    for message in self._pending_receive.values():
        message.request.confirm_received.set()
    self._pending_receive.clear()
def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong')
@ -259,22 +267,17 @@ class MtProtoSender:
def _handle_container(self, msg_id, sequence, reader, state):
self._logger.debug('Handling container')
reader.read_int(signed=False) # code
size = reader.read_int()
for _ in range(size):
inner_msg_id = reader.read_long()
reader.read_int() # inner_sequence
inner_length = reader.read_int()
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
begin_position = reader.tell_position()
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not self._process_msg(inner_msg_id, sequence, reader, state):
reader.set_position(begin_position + inner_length)
reader.set_position(begin_position + inner_len)
except:
# If any error is raised, something went wrong; skip the packet
reader.set_position(begin_position + inner_length)
reader.set_position(begin_position + inner_len)
raise
return True
@ -306,7 +309,6 @@ class MtProtoSender:
# sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id)
self.session.save()
self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.')
return True
@ -334,7 +336,7 @@ class MtProtoSender:
if self.session.report_errors and request:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
report_method=type(request).constructor_id
report_method=type(request).CONSTRUCTOR_ID
)
else:
error = rpc_message_to_error(
@ -372,11 +374,7 @@ class MtProtoSender:
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data')
reader.read_int(signed=False) # code
packed_data = reader.tgread_bytes()
unpacked_data = gzip.decompress(packed_data)
with BinaryReader(unpacked_data) as compressed_reader:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
return self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion

View File

@ -1,20 +1,24 @@
import logging
from datetime import timedelta
import os
import threading
from datetime import timedelta, datetime
from hashlib import md5
from io import BytesIO
from os import path
from threading import Lock
from time import sleep
from . import helpers as utils
from .crypto import rsa, CdnDecrypter
from .errors import (
RPCError, BrokenAuthKeyError,
FloodWaitError, FileMigrateError, TypeNotFoundError
RPCError, BrokenAuthKeyError, ServerError,
FloodWaitError, FileMigrateError, TypeNotFoundError,
UnauthorizedError, PhoneMigrateError, NetworkMigrateError, UserMigrateError
)
from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .tl import TLObject, Session
from .tl.all_tlobjects import LAYER
from .tl.functions import (
InitConnectionRequest, InvokeWithLayerRequest
InitConnectionRequest, InvokeWithLayerRequest, PingRequest
)
from .tl.functions.auth import (
ImportAuthorizationRequest, ExportAuthorizationRequest
@ -22,6 +26,7 @@ from .tl.functions.auth import (
from .tl.functions.help import (
GetCdnConfigRequest, GetConfigRequest
)
from .tl.functions.updates import GetStateRequest
from .tl.functions.upload import (
GetFileRequest, SaveBigFilePartRequest, SaveFilePartRequest
)
@ -52,7 +57,7 @@ class TelegramBareClient:
"""
# Current TelegramClient version
__version__ = '0.13.4'
__version__ = '0.14.2'
# TODO Make this thread-safe, all connections share the same DC
_dc_options = None
@ -62,63 +67,124 @@ class TelegramBareClient:
def __init__(self, session, api_id, api_hash,
connection_mode=ConnectionMode.TCP_FULL,
proxy=None,
process_updates=False,
timeout=timedelta(seconds=5)):
"""Initializes the Telegram client with the specified API ID and Hash.
Session must always be a Session instance, and an optional proxy
can also be specified to be used on the connection.
"""
update_workers=None,
spawn_read_thread=False,
timeout=timedelta(seconds=5),
**kwargs):
"""Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash:
raise PermissionError(
"Your API ID or Hash cannot be empty or None. "
"Refer to Telethon's README.rst for more information.")
# Determine what session object we have
if isinstance(session, str) or session is None:
session = Session.try_load_or_create_new(session)
elif not isinstance(session, Session):
raise ValueError(
'The given session must be a str or a Session instance.'
)
self.session = session
self.api_id = int(api_id)
self.api_hash = api_hash
if self.api_id < 20: # official apps must use obfuscated
self._connection_mode = ConnectionMode.TCP_OBFUSCATED
else:
self._connection_mode = connection_mode
self.proxy = proxy
self._timeout = timeout
connection_mode = ConnectionMode.TCP_OBFUSCATED
# This is the main sender, which will be used from the thread
# that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always
# kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection(
self.session.server_address, self.session.port,
mode=connection_mode, proxy=proxy, timeout=timeout
))
self._logger = logging.getLogger(__name__)
# Cache "exported" senders 'dc_id: TelegramBareClient' and
# their corresponding sessions not to recreate them all
# the time since it's a (somewhat expensive) process.
self._cached_clients = {}
# Two threads may be calling reconnect() when the connection is lost,
# we only want one to actually perform the reconnection.
self._reconnect_lock = Lock()
# Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively
# expensive operation.
self._exported_sessions = {}
# This member will process updates if enabled.
# One may change self.updates.enabled at any later point.
self.updates = UpdateState(process_updates)
self.updates = UpdateState(workers=update_workers)
# These will be set later
self._sender = None
# Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
for name, value in kwargs.items():
if not hasattr(self.session, name):
raise ValueError('Unknown named parameter', name)
setattr(self.session, name, value)
# Despite the state of the real connection, keep track of whether
# the user has explicitly called .connect() or .disconnect() here.
# This information is required by the read thread, who will be the
# one attempting to reconnect on the background *while* the user
# doesn't explicitly call .disconnect(), thus telling it to stop
# retrying. The main thread, knowing there is a background thread
# attempting reconnection as soon as it happens, will just sleep.
self._user_connected = False
# Save whether the user is authorized here (a.k.a. logged in)
self._authorized = False
# Uploaded files cache so subsequent calls are instant
self._upload_cache = {}
# Constantly read for results and updates from within the main client,
# if the user has left enabled such option.
self._spawn_read_thread = spawn_read_thread
self._recv_thread = None
# Identifier of the main thread (the one that called .connect()).
# This will be used to create new connections from any other thread,
# so that requests can be sent in parallel.
self._main_thread_ident = None
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
# endregion
# region Connecting
def connect(self, exported_auth=None):
def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
"""Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which
may require a call (or several) to 'sign_in' for the first time.
If 'exported_auth' is not None, it will be used instead to
determine the authorization key for the current session.
"""
if self.is_connected():
return True
Note that the optional parameters are meant for internal use.
connection = Connection(
self.session.server_address, self.session.port,
mode=self._connection_mode, proxy=self.proxy, timeout=self._timeout
)
If '_exported_auth' is not None, it will be used instead to
determine the authorization key for the current session.
If '_sync_updates', sync_updates() will be called and a
second thread will be started if necessary. Note that this
will FAIL if the client is not connected to the user's
native data center, raising a "UserMigrateError", and
calling .disconnect() in the process.
If '_cdn' is True, methods that are not allowed on such data
centers won't be invoked.
"""
self._main_thread_ident = threading.get_ident()
try:
self._sender.connect()
if not self.session.auth_key:
# New key, we need to tell the server we're going to use
# the latest layer
try:
self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(connection)
authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError:
return False
@ -128,34 +194,47 @@ class TelegramBareClient:
else:
init_connection = self.session.layer != LAYER
self._sender = MtProtoSender(connection, self.session)
self._sender.connect()
if init_connection:
if exported_auth is not None:
if _exported_auth is not None:
self._init_connection(ImportAuthorizationRequest(
exported_auth.id, exported_auth.bytes
_exported_auth.id, _exported_auth.bytes
))
else:
elif not _cdn:
TelegramBareClient._dc_options = \
self._init_connection(GetConfigRequest()).dc_options
elif exported_auth is not None:
elif _exported_auth is not None:
self(ImportAuthorizationRequest(
exported_auth.id, exported_auth.bytes
_exported_auth.id, _exported_auth.bytes
))
if TelegramBareClient._dc_options is None:
if TelegramBareClient._dc_options is None and not _cdn:
TelegramBareClient._dc_options = \
self(GetConfigRequest()).dc_options
# Connection was successful! Try syncing the update state
# UNLESS '_sync_updates' is False (we probably are in
# another data center and this would raise UserMigrateError)
# to also assert whether the user is logged in or not.
self._user_connected = True
if _sync_updates and not _cdn:
try:
self.sync_updates()
self._set_connected_and_authorized()
except UnauthorizedError:
self._authorized = False
return True
except TypeNotFoundError as e:
# This is fine, probably layer migration
self._logger.debug('Found invalid item, probably migrating', e)
self.disconnect()
return self.connect(exported_auth=exported_auth)
return self.connect(
_exported_auth=_exported_auth,
_sync_updates=_sync_updates,
_cdn=_cdn
)
except (RPCError, ConnectionError) as error:
# Probably errors from the previous session, ignore them
@ -166,7 +245,7 @@ class TelegramBareClient:
return False
def is_connected(self):
return self._sender is not None and self._sender.is_connected()
return self._sender.is_connected()
def _init_connection(self, query=None):
result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
@ -184,31 +263,54 @@ class TelegramBareClient:
return result
def disconnect(self):
"""Disconnects from the Telegram server"""
if self._sender:
self._sender.disconnect()
self._sender = None
"""Disconnects from the Telegram server
and stops all the spawned threads"""
self._user_connected = False
self._recv_thread = None
def reconnect(self, new_dc=None):
"""Disconnects and connects again (effectively reconnecting).
# This will trigger a "ConnectionResetError", for subsequent calls
# to read or send (from another thread) and usually, the background
# thread would try restarting the connection but since the
# ._recv_thread = None, it knows it doesn't have to.
self._sender.disconnect()
If 'new_dc' is not None, the current authorization key is
removed, the DC used is switched, and a new connection is made.
# TODO Shall we clear the _exported_sessions, or may be reused?
pass
def _reconnect(self, new_dc=None):
"""If 'new_dc' is not set, only a call to .connect() will be made
since it's assumed that the connection has been lost and the
library is reconnecting.
If 'new_dc' is set, the client is first disconnected from the
current data center, clears the auth key for the old DC, and
connects to the new data center.
"""
self.disconnect()
if new_dc is not None:
if new_dc is None:
# Assume we are disconnected due to some error, so connect again
with self._reconnect_lock:
# Another thread may have connected again, so check that first
if not self.is_connected():
return self.connect()
else:
return True
else:
self.disconnect()
self.session.auth_key = None # Force creating new auth_key
dc = self._get_dc(new_dc)
self.session.server_address = dc.ip_address
self.session.port = dc.port
ip = dc.ip_address
self._sender.connection.ip = self.session.server_address = ip
self._sender.connection.port = self.session.port = dc.port
self.session.save()
self.connect()
return self.connect()
# endregion
# region Working with different Data Centers
# region Working with different connections/Data Centers
def _on_read_thread(self):
    """Tells whether the calling thread is the one in self._recv_thread.

       Returns True only if self._recv_thread is set and its 'ident'
       equals threading.get_ident() for the caller, False otherwise.
    """
    # Read the attribute only once: .disconnect() sets it to None, and
    # if that happened between the None check and the '.ident' access
    # this would fail with AttributeError.
    recv_thread = self._recv_thread
    return recv_thread is not None and \
        threading.get_ident() == recv_thread.ident
def _get_dc(self, dc_id, ipv6=False, cdn=False):
"""Gets the Data Center (DC) associated to 'dc_id'"""
@ -235,30 +337,23 @@ class TelegramBareClient:
TelegramBareClient._dc_options = self(GetConfigRequest()).dc_options
return self._get_dc(dc_id, ipv6=ipv6, cdn=cdn)
def _get_exported_client(self, dc_id,
init_connection=False,
bypass_cache=False):
"""Gets a cached exported TelegramBareClient for the desired DC.
def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC.
If it's the first time retrieving the TelegramBareClient, the
current authorization is exported to the new DC so that
it can be used there, and the connection is initialized.
If after using the sender a ConnectionResetError is raised,
this method should be called again with init_connection=True
in order to perform the reconnection.
If bypass_cache is True, a new client will be exported and
it will not be cached.
If it's the first time calling the method with a given dc_id,
a new session will be first created, and its auth key generated.
Exporting/Importing the authorization will also be done so that
the auth is bound with the key.
"""
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization! ^^
client = self._cached_clients.get(dc_id)
if client and not bypass_cache:
if init_connection:
client.reconnect()
return client
session = self._exported_sessions.get(dc_id)
if session:
export_auth = None # Already bound with the auth key
else:
# TODO Add a lock, don't allow two threads to create an auth key
# for the same data center (when calling .connect() if there
# wasn't a previous session).
dc = self._get_dc(dc_id)
# Export the current authorization to the new DC.
@ -272,80 +367,172 @@ class TelegramBareClient:
session = Session(self.session)
session.server_address = dc.ip_address
session.port = dc.port
client = TelegramBareClient(
session, self.api_id, self.api_hash,
timeout=self._timeout
)
client.connect(exported_auth=export_auth)
self._exported_sessions[dc_id] = session
if not bypass_cache:
# Don't go through this expensive process every time.
self._cached_clients[dc_id] = client
return client
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
)
client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth
return client
def _get_cdn_client(self, cdn_redirect):
    """Build and connect a TelegramBareClient for a CDN data center.

    Similar to ._get_exported_client, but for CDNs. The Session for
    ``cdn_redirect.dc_id`` is taken from ``self._exported_sessions``;
    when there is none, one is built with ``Session(self.session)``,
    pointed at the address returned by ``self._get_dc(..., cdn=True)``
    and stored back in ``self._exported_sessions``.

    Returns the connected client, whose ``_authorized`` flag is copied
    from this client.
    """
    cdn_id = cdn_redirect.dc_id
    cdn_session = self._exported_sessions.get(cdn_id)
    if not cdn_session:
        cdn_dc = self._get_dc(cdn_id, cdn=True)
        cdn_session = Session(self.session)
        cdn_session.server_address = cdn_dc.ip_address
        cdn_session.port = cdn_dc.port
        self._exported_sessions[cdn_id] = cdn_session

    # Reuse the proxy and timeout of the main connection.
    cdn_client = TelegramBareClient(
        cdn_session, self.api_id, self.api_hash,
        proxy=self._sender.connection.conn.proxy,
        timeout=self._sender.connection.get_timeout()
    )

    # This will make use of the new RSA keys for this specific CDN.
    #
    # This relies on the fact that TelegramBareClient._dc_options is
    # static and it won't be called from this DC (it would fail).
    cdn_client.connect(_cdn=True)  # Avoid invoking non-CDN specific methods
    cdn_client._authorized = self._authorized
    return cdn_client
# endregion
# region Invoking Telegram requests
def invoke(self, request, call_receive=True, retries=5):
def invoke(self, *requests, retries=5):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
If 'updates' is not None, all read update object will be put
in such list. Otherwise, update objects will be ignored.
If 'call_receive' is set to False, then there should be another
thread calling to 'self._sender.receive()' running or this method
will lock forever.
The invoke will be retried up to 'retries' times before raising
ValueError().
"""
if not isinstance(request, TLObject) and not request.content_related:
if not all(isinstance(x, TLObject) and
x.content_related for x in requests):
raise ValueError('You can only invoke requests, not types!')
if not self._sender:
raise ValueError('You must be connected to invoke requests!')
# Determine the sender to be used (main or a new connection)
on_main_thread = threading.get_ident() == self._main_thread_ident
if on_main_thread or self._on_read_thread():
sender = self._sender
else:
sender = self._sender.clone()
sender.connect()
# We should call receive from this thread if there's no background
# thread reading or if the server disconnected us and we're trying
# to reconnect. This is because the read thread may either be
# locked also trying to reconnect or we may be said thread already.
call_receive = not on_main_thread or self._recv_thread is None \
or self._reconnect_lock.locked()
try:
for _ in range(retries):
result = self._invoke(sender, call_receive, *requests)
if result:
return result
if retries <= 0:
raise ValueError('Number of retries reached 0.')
finally:
if sender != self._sender:
sender.disconnect() # Close temporary connections
def _invoke(self, sender, call_receive, *requests):
try:
# Ensure that we start with no previous errors (i.e. resending)
request.confirm_received.clear()
request.rpc_error = None
for x in requests:
x.confirm_received.clear()
x.rpc_error = None
sender.send(*requests)
self._sender.send(request)
if not call_receive:
# TODO This will be slightly troublesome if we allow
# switching between constant read or not on the fly.
# Must also watch out for calling .read() from two places,
# in which case a Lock would be required for .receive().
request.confirm_received.wait() # TODO Socket's timeout here?
for x in requests:
x.confirm_received.wait(
sender.connection.get_timeout()
)
else:
while not request.confirm_received.is_set():
self._sender.receive(update_state=self.updates)
while not all(x.confirm_received.is_set() for x in requests):
sender.receive(update_state=self.updates)
except TimeoutError:
pass # We will just retry
except ConnectionResetError:
if not self._authorized or self._reconnect_lock.locked():
# Only attempt reconnecting if we're authorized and not
# reconnecting already.
raise
self._logger.debug('Server disconnected us. Reconnecting and '
'resending request...')
self.reconnect()
if sender != self._sender:
# TODO Try reconnecting forever too?
sender.connect()
else:
while self._user_connected and not self._reconnect():
sleep(0.1) # Retry forever until we can send the request
finally:
if sender != self._sender:
sender.disconnect()
try:
raise next(x.rpc_error for x in requests if x.rpc_error)
except StopIteration:
if any(x.result is None for x in requests):
# "A container may only be accepted or
# rejected by the other party as a whole."
return None
elif len(requests) == 1:
return requests[0].result
else:
return [x.result for x in requests]
except (PhoneMigrateError, NetworkMigrateError,
UserMigrateError) as e:
self._logger.debug(
'DC error when invoking request, '
'attempting to reconnect at DC {}'.format(e.new_dc)
)
# TODO What happens with the background thread here?
# For normal use cases, this won't happen, because this will only
# be on the very first connection (not authorized, not running),
# but may be an issue for people who actually travel?
self._reconnect(new_dc=e.new_dc)
return self._invoke(sender, call_receive, *requests)
except ServerError as e:
# Telegram is having some issues, just retry
self._logger.debug(
'[ERROR] Telegram is having some internal issues', e
)
except FloodWaitError:
sender.disconnect()
self.disconnect()
raise
if request.rpc_error:
raise request.rpc_error
if request.result is None:
return self.invoke(
request, call_receive=call_receive, retries=(retries - 1)
)
else:
return request.result
# Let people use client(SomeRequest()) instead client.invoke(...)
__call__ = invoke
# Some really basic functionality
def is_user_authorized(self):
    """Report whether a user is authorized on this client.

    Returns the cached ``self._authorized`` flag (i.e. whether the
    code request was sent and confirmed); no request is sent.
    """
    authorized = self._authorized
    return authorized
# endregion
# region Uploading media
@ -371,10 +558,10 @@ class TelegramBareClient:
Default values for the optional parameters if left as None are:
part_size_kb = get_appropriated_part_size(file_size)
file_name = path.basename(file_path)
file_name = os.path.basename(file_path)
"""
if isinstance(file, str):
file_size = path.getsize(file)
file_size = os.path.getsize(file)
elif isinstance(file, bytes):
file_size = len(file)
else:
@ -430,7 +617,7 @@ class TelegramBareClient:
# Set a default file name if None was specified
if not file_name:
if isinstance(file, str):
file_name = path.basename(file)
file_name = os.path.basename(file)
else:
file_name = str(file_id)
@ -499,7 +686,7 @@ class TelegramBareClient:
if isinstance(result, FileCdnRedirect):
cdn_decrypter, result = \
CdnDecrypter.prepare_decrypter(
client, TelegramBareClient, result
client, self._get_cdn_client(result), result
)
except FileMigrateError as e:
@ -518,6 +705,9 @@ class TelegramBareClient:
if progress_callback:
progress_callback(f.tell(), file_size)
finally:
if client != self:
client.disconnect()
if cdn_decrypter:
try:
cdn_decrypter.client.disconnect()
@ -527,3 +717,80 @@ class TelegramBareClient:
f.close()
# endregion
# region Updates handling
def sync_updates(self):
    """Synchronize ``self.updates`` with its initial state.

    Invokes GetStateRequest through this client and hands the result
    to ``self.updates.process``. Will be called automatically on
    connection if self.updates.enabled = True, otherwise it should be
    called manually after enabling updates.
    """
    state = self(GetStateRequest())
    self.updates.process(state)
def add_update_handler(self, handler):
    """Register ``handler`` to listen for updates.

    ``handler`` is a function which takes a TLObject (an update) as its
    parameter. When it is the first handler being registered,
    ``self.sync_updates()`` is called after appending it.
    """
    registered = self.updates.handlers
    is_first = not registered
    registered.append(handler)
    if is_first:
        self.sync_updates()
def remove_update_handler(self, handler):
    """Unregister ``handler`` from ``self.updates.handlers``.

    Uses ``list.remove``, so a handler that was never added raises
    ValueError.
    """
    registered = self.updates.handlers
    registered.remove(handler)
def list_update_handlers(self):
    """Return a shallow copy of the registered update handlers.

    Mutating the returned list does not affect ``self.updates.handlers``.
    """
    snapshot = self.updates.handlers[0:]
    return snapshot
# endregion
# Constant read
def _set_connected_and_authorized(self):
    """Mark the client as authorized and start the read thread if wanted.

    The daemon 'ReadThread' (running ``self._recv_thread_impl``) is
    only created when ``self._spawn_read_thread`` is truthy and no
    thread is currently stored in ``self._recv_thread``.
    """
    self._authorized = True
    if not self._spawn_read_thread or self._recv_thread is not None:
        return

    reader = threading.Thread(
        target=self._recv_thread_impl,
        name='ReadThread', daemon=True
    )
    self._recv_thread = reader
    reader.start()
# By using this approach, another thread will be
# created and started upon connection to constantly read
# from the other end. Otherwise, manual calls to .receive()
# must be performed. The MtProtoSender cannot be connected,
# or an error will be thrown.
#
# This way, sending and receiving will be completely independent.
def _recv_thread_impl(self):
    """Body of the background read thread.

    While ``self._user_connected`` is truthy, this loop:

    * sends a PingRequest whenever more than ``self._ping_delay`` has
      elapsed since ``self._last_ping``;
    * calls ``self._sender.receive`` so results and updates are read
      and fed into ``self.updates``;
    * ignores TimeoutError and, on ConnectionResetError, keeps calling
      ``self._reconnect()`` until it succeeds or the user disconnects.

    Any other exception is logged (including the exception text) and
    ends the loop. On exit ``self._recv_thread`` is reset to None.
    """
    while self._user_connected:
        try:
            if datetime.now() > self._last_ping + self._ping_delay:
                self._sender.send(PingRequest(
                    int.from_bytes(os.urandom(8), 'big', signed=True)
                ))
                self._last_ping = datetime.now()

            self._sender.receive(update_state=self.updates)
        except TimeoutError:
            # No problem.
            pass
        except ConnectionResetError:
            self._logger.debug('Server disconnected us. Reconnecting...')
            while self._user_connected and not self._reconnect():
                sleep(0.1)  # Retry forever, this is instant messaging
        except Exception as error:
            # Unknown exception: log it and stop reading. The message
            # needs a %s placeholder, otherwise the logging module
            # fails to format the record and the error text is lost.
            self._logger.debug(
                '[ERROR] Unknown error on the read thread, '
                'please report: %s',
                error
            )
            # If something strange happens we don't want to enter an
            # infinite loop where all we do is raise an exception, so
            # add a little sleep to avoid the CPU usage going mad.
            sleep(0.1)
            break

    self._recv_thread = None
# endregion

View File

@ -1,20 +1,21 @@
import os
import threading
from datetime import datetime, timedelta
from functools import lru_cache
from mimetypes import guess_type
from threading import Thread
try:
import socks
except ImportError:
socks = None
from . import TelegramBareClient
from . import helpers as utils
from .errors import (
RPCError, UnauthorizedError, InvalidParameterError, PhoneCodeEmptyError,
PhoneMigrateError, NetworkMigrateError, UserMigrateError,
PhoneCodeExpiredError, PhoneCodeHashEmptyError, PhoneCodeInvalidError
)
from .network import ConnectionMode
from .tl import Session, TLObject
from .tl.functions import PingRequest
from .tl import TLObject
from .tl.functions.account import (
GetPasswordRequest
)
@ -29,9 +30,6 @@ from .tl.functions.messages import (
GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest,
SendMessageRequest
)
from .tl.functions.updates import (
GetStateRequest
)
from .tl.functions.users import (
GetUsersRequest
)
@ -43,6 +41,7 @@ from .tl.types import (
InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID,
UpdateNewMessage, UpdateShortSentMessage
)
from .tl.types.messages import DialogsSlice
from .utils import find_user_or_chat, get_extension
@ -59,8 +58,9 @@ class TelegramClient(TelegramBareClient):
def __init__(self, session, api_id, api_hash,
connection_mode=ConnectionMode.TCP_FULL,
proxy=None,
process_updates=False,
update_workers=None,
timeout=timedelta(seconds=5),
spawn_read_thread=True,
**kwargs):
"""Initializes the Telegram client with the specified API ID and Hash.
@ -73,15 +73,21 @@ class TelegramClient(TelegramBareClient):
This will only affect how messages are sent over the network
and how much processing is required before sending them.
If 'process_updates' is set to True, incoming updates will be
processed and you must manually call 'self.updates.poll()' from
another thread to retrieve the saved update objects, or your
memory will fill with these. You may modify the value of
'self.updates.polling' at any later point.
The integer 'update_workers' represents depending on its value:
is None: Updates will *not* be stored in memory.
= 0: Another thread is responsible for calling self.updates.poll()
> 0: 'update_workers' background threads will be spawned, and
any of them will invoke all the self.updates.handlers.
Despite the value of 'process_updates', if you later call
'.add_update_handler(...)', updates will also be processed
and the update objects will be passed to the handlers you added.
If 'spawn_read_thread', a background thread will be started once
an authorized user has been logged in to Telegram to read items
(such as updates and responses) from the network as soon as they
occur, which will speed things up.
If you don't want to spawn any additional threads, pending updates
will be read and processed accordingly after invoking a request
and not immediately. This is useful if you don't care about updates
at all and have set 'update_workers=None'.
If more named arguments are provided as **kwargs, they will be
used to update the Session instance. Most common settings are:
@ -92,210 +98,55 @@ class TelegramClient(TelegramBareClient):
system_lang_code = lang_code
report_errors = True
"""
if not api_id or not api_hash:
raise PermissionError(
"Your API ID or Hash cannot be empty or None. "
"Refer to Telethon's README.rst for more information.")
# Determine what session object we have
if isinstance(session, str) or session is None:
session = Session.try_load_or_create_new(session)
elif not isinstance(session, Session):
raise ValueError(
'The given session must be a str or a Session instance.')
super().__init__(
session, api_id, api_hash,
connection_mode=connection_mode,
proxy=proxy,
process_updates=process_updates,
timeout=timeout
update_workers=update_workers,
spawn_read_thread=spawn_read_thread,
timeout=timeout,
**kwargs
)
# Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
for name, value in kwargs.items():
if hasattr(self.session, name):
setattr(self.session, name, value)
self._updates_thread = None
# Some fields to easy signing in
self._phone_code_hash = None
self._phone = None
# Uploaded files cache so subsequent calls are instant
self._upload_cache = {}
# Constantly read for results and updates from within the main client
self._recv_thread = None
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
# endregion
# region Connecting
def connect(self, exported_auth=None):
"""Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which
may require a call (or several) to 'sign_in' for the first time.
exported_auth is meant for internal purposes and can be ignored.

Returns None right away when the sender is already connected;
otherwise returns the value given by the parent class' connect().
"""
# Already connected: nothing to do.
if self._sender and self._sender.is_connected():
return
ok = super().connect(exported_auth=exported_auth)
# The main TelegramClient is the only one that will have
# constant_read, since it's also the only one who receives
# updates and need to be processed as soon as they occur.
#
# TODO Allow to disable this to avoid the creation of a new thread
# if the user is not going to work with updates at all? Whether to
# read constantly or not for updates needs to be known before hand,
# and further updates won't be able to be added unless allowing to
# switch the mode on the fly.
if ok:
# Spawn the daemon thread that runs self._recv_thread_impl.
self._recv_thread = Thread(
name='ReadThread', daemon=True,
target=self._recv_thread_impl
)
self._recv_thread.start()
# Fetch the initial update state when updates are being polled.
if self.updates.polling:
self.sync_updates()
return ok
def disconnect(self):
    """Disconnect from the Telegram server and stop spawned threads.

    Does nothing when there is no connected sender. Otherwise the read
    thread reference is cleared, the parent class' disconnect() runs,
    and every client in ``self._cached_clients`` is disconnected before
    that cache is emptied.
    """
    if not (self._sender and self._sender.is_connected()):
        return

    # The existing thread will close eventually, since it's
    # only running while the MtProtoSender.is_connected()
    self._recv_thread = None

    # This will trigger a "ConnectionResetError", usually, the background
    # thread would try restarting the connection but since the
    # ._recv_thread = None, it knows it doesn't have to.
    super().disconnect()

    # Also disconnect all the cached senders
    cached = self._cached_clients
    for exported in cached.values():
        exported.disconnect()
    cached.clear()
# endregion
# region Working with different connections
def create_new_connection(self, on_dc=None, timeout=timedelta(seconds=5)):
    """Create a connection usable in parallel with this TelegramClient.

    A TelegramBareClient is returned already connected, and the caller
    is responsible for disconnecting it.

    If 'on_dc' is None, the new client runs on the same data center as
    the current client (most common case) and uses 'timeout'. Otherwise
    'on_dc' is the data center ID, and the client comes from
    ``self._get_exported_client(on_dc, bypass_cache=True)``.
    """
    if on_dc is not None:
        return self._get_exported_client(on_dc, bypass_cache=True)

    same_dc_client = TelegramBareClient(
        self.session, self.api_id, self.api_hash,
        proxy=self.proxy, timeout=timeout
    )
    same_dc_client.connect()
    return same_dc_client
# endregion
# region Telegram requests functions
def invoke(self, request, *args, **kwargs):
    """Invokes (sends) a MTProtoRequest and returns (receives) its result.

    An optional 'retries' keyword can be set (defaults to 5); it is
    also honoured when the request is re-sent after a DC migration.
    *args will be ignored.

    Raises AssertionError when called from the ReadThread itself.
    """
    if self._recv_thread is not None and \
            threading.get_ident() == self._recv_thread.ident:
        raise AssertionError('Cannot invoke requests from the ReadThread')

    self.updates.check_error()

    try:
        # Users may call this method from within some update handler.
        # If this is the case, then the thread invoking the request
        # will be the one which should be reading (but is invoking the
        # request) thus not being available to read it "in the background"
        # and it's needed to call receive.
        return super().invoke(
            request, call_receive=self._recv_thread is None,
            retries=kwargs.get('retries', 5)
        )
    except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e:
        self._logger.debug('DC error when invoking request, '
                           'attempting to reconnect at DC {}'
                           .format(e.new_dc))

        self.reconnect(new_dc=e.new_dc)
        # Forward the caller's arguments so that a custom 'retries'
        # value is not silently reset to the default on this retry.
        return self.invoke(request, *args, **kwargs)
# Let people use client(SomeRequest()) instead client.invoke(...)
__call__ = invoke
def invoke_on_dc(self, request, dc_id, reconnect=False):
    """Invoke ``request`` on a different DC through an exported client.

    The client comes from ``self._get_exported_client``. If
    'reconnect=True', a reconnection is performed first and a
    ConnectionResetError is re-raised to the caller; otherwise the
    first ConnectionResetError triggers one retry with reconnect=True.
    """
    try:
        exported = self._get_exported_client(
            dc_id, init_connection=reconnect)
        return exported.invoke(request)
    except ConnectionResetError:
        if not reconnect:
            # First failure: try once more, reconnecting this time.
            return self.invoke_on_dc(request, dc_id, reconnect=True)
        raise
# region Authorization requests
def is_user_authorized(self):
    """Has the user been authorized yet
    (code request sent and confirmed)?

    A falsy ``self.session`` is returned as-is; otherwise the result is
    whether ``self.get_me()`` returned something other than None.
    """
    session = self.session
    if not session:
        return session
    return self.get_me() is not None
def send_code_request(self, phone):
"""Sends a code request to the specified phone number"""
result = self(
SendCodeRequest(phone, self.api_id, self.api_hash))
if isinstance(phone, int):
phone = str(phone)
elif phone.startswith('+'):
phone = phone.strip('+')
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone
self._phone_code_hash = result.phone_code_hash
return result
def sign_in(self, phone=None, code=None,
password=None, bot_token=None):
password=None, bot_token=None, phone_code_hash=None):
"""Completes the sign in process with the phone number + code pair.
If no phone or code is provided, then the sole password will be used.
The password should be used after a normal authorization attempt
has happened and a SessionPasswordNeededError was raised.
If you're calling .sign_in() on two completely different clients
(for example, through an API that creates a new client per phone),
you must first call .sign_in(phone) to receive the code, and then
with the result that method returns, call
.sign_in(phone, code, phone_code_hash=result.phone_code_hash).
If this is done on the same client, the client will fill said values
for you.
To login as a bot, only `bot_token` should be provided.
This should equal to the bot access hash provided by
https://t.me/BotFather during your bot creation.
@ -306,64 +157,66 @@ class TelegramClient(TelegramBareClient):
if phone and not code:
return self.send_code_request(phone)
elif code:
if self._phone is None:
phone = phone or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash
if not phone:
raise ValueError(
'Please make sure to call send_code_request first.')
'Please make sure to call send_code_request first.'
)
if not phone_code_hash:
raise ValueError('You also need to provide a phone_code_hash.')
try:
if isinstance(code, int):
code = str(code)
result = self(SignInRequest(
self._phone, self._phone_code_hash, code
phone, phone_code_hash, code
))
except (PhoneCodeEmptyError, PhoneCodeExpiredError,
PhoneCodeHashEmptyError, PhoneCodeInvalidError):
return None
elif password:
salt = self(GetPasswordRequest()).current_salt
result = self(
CheckPasswordRequest(utils.get_password_hash(password, salt)))
result = self(CheckPasswordRequest(
utils.get_password_hash(password, salt)
))
elif bot_token:
result = self(ImportBotAuthorizationRequest(
flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash))
api_id=self.api_id, api_hash=self.api_hash
))
else:
raise ValueError(
'You must provide a phone and a code the first time, '
'and a password only if an RPCError was raised before.')
'and a password only if an RPCError was raised before.'
)
self._set_connected_and_authorized()
return result.user
def sign_up(self, code, first_name, last_name=''):
"""Signs up to Telegram. Make sure you sent a code request first!"""
return self(SignUpRequest(
result = self(SignUpRequest(
phone_number=self._phone,
phone_code_hash=self._phone_code_hash,
phone_code=code,
first_name=first_name,
last_name=last_name
)).user
))
self._set_connected_and_authorized()
return result.user
def log_out(self):
"""Logs out and deletes the current session.
Returns True if everything went okay."""
# Special flag when logging out (so the ack request confirms it)
self._sender.logging_out = True
try:
self(LogOutRequest())
# The server may have already disconnected us, we still
# try to disconnect to make sure.
self.disconnect()
except (RPCError, ConnectionError):
# Something happened when logging out, restore the state back
self._sender.logging_out = False
except RPCError:
return False
self.disconnect()
self.session.delete()
self.session = None
return True
@ -386,22 +239,61 @@ class TelegramClient(TelegramBareClient):
offset_id=0,
offset_peer=InputPeerEmpty()):
"""Returns a tuple of lists ([dialogs], [entities])
with at least 'limit' items each.
with at least 'limit' items each unless all dialogs were consumed.
If `limit` is None, all dialogs (from the given
offset) will be retrieved.
If `limit` is 0, all dialogs will (should) retrieved.
The `entities` represent the user, chat or channel
corresponding to that dialog.
corresponding to that dialog. If it's an integer, not
all dialogs may be retrieved at once.
"""
if limit is None:
limit = float('inf')
r = self(
GetDialogsRequest(
dialogs = {} # Use Dialog.top_message as identifier to avoid dupes
messages = {} # Used later for sorting TODO also return these?
entities = {}
while len(dialogs) < limit:
r = self(GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=limit))
limit=0 # limit 0 often means "as much as possible"
))
if not r.dialogs:
break
for d in r.dialogs:
dialogs[d.top_message] = d
for m in r.messages:
messages[m.id] = m
# We assume users can't have the same ID as a chat
for u in r.users:
entities[u.id] = u
for c in r.chats:
entities[c.id] = c
if isinstance(r, DialogsSlice):
# Don't enter next iteration if we already got all
break
offset_date = r.messages[-1].date
offset_peer = find_user_or_chat(r.dialogs[-1].peer, entities,
entities)
offset_id = r.messages[-1].id & 4294967296 # Telegram/danog magic
# Sort by message date
no_date = datetime.fromtimestamp(0)
dialogs = sorted(
list(dialogs.values()),
key=lambda d: getattr(messages[d.top_message], 'date', no_date)
)
return (
r.dialogs,
[find_user_or_chat(d.peer, r.users, r.chats) for d in r.dialogs])
dialogs,
[find_user_or_chat(d.peer, entities, entities) for d in dialogs]
)
# endregion
@ -427,7 +319,7 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_reply_to(reply_to)
)
result = self(request)
if isinstance(request, UpdateShortSentMessage):
if isinstance(result, UpdateShortSentMessage):
return Message(
id=result.id,
to_id=entity,
@ -540,7 +432,7 @@ class TelegramClient(TelegramBareClient):
return reply_to
if isinstance(reply_to, TLObject) and \
type(reply_to).subclass_of_id == 0x790009e3:
type(reply_to).SUBCLASS_OF_ID == 0x790009e3:
# hex(crc32(b'Message')) = 0x790009e3
return reply_to.id
@ -972,77 +864,3 @@ class TelegramClient(TelegramBareClient):
)
# endregion
# region Updates handling
def sync_updates(self):
    """Synchronize ``self.updates`` with its initial state.

    Will be called automatically on connection if
    self.updates.enabled = True, otherwise it should be called manually
    after enabling updates.

    Returns True on success, or False when an UnauthorizedError was
    raised while doing so.
    """
    try:
        state = self(GetStateRequest())
        self.updates.process(state)
    except UnauthorizedError:
        return False
    else:
        return True
def add_update_handler(self, handler):
    """Add an update handler and start listening for updates.

    ``handler`` is a function which takes a TLObject (an update) as its
    parameter. If no handler was registered before this call,
    ``self.sync_updates()`` is invoked once the handler is appended.
    """
    handlers = self.updates.handlers
    needs_sync = not handlers
    handlers.append(handler)
    if needs_sync:
        self.sync_updates()
def remove_update_handler(self, handler):
    """Remove ``handler`` from the registered update handlers.

    Raises ValueError (from ``list.remove``) if it was not registered.
    """
    handlers = self.updates.handlers
    handlers.remove(handler)
def list_update_handlers(self):
    """Return a copy of the list of registered update handlers."""
    handlers = self.updates.handlers
    return handlers[0:len(handlers)]
# endregion
# Constant read
# By using this approach, another thread will be
# created and started upon connection to constantly read
# from the other end. Otherwise, manual calls to .receive()
# must be performed. The MtProtoSender cannot be connected,
# or an error will be thrown.
#
# This way, sending and receiving will be completely independent.
def _recv_thread_impl(self):
# Body of the 'ReadThread': keeps reading from the sender while it
# exists and reports being connected, pinging when the ping delay
# has elapsed since the last ping.
while self._sender and self._sender.is_connected():
try:
# Send a PingRequest with a random signed 64-bit ID when due.
if datetime.now() > self._last_ping + self._ping_delay:
self._sender.send(PingRequest(
int.from_bytes(os.urandom(8), 'big', signed=True)
))
self._last_ping = datetime.now()
# Read one item; updates are fed into self.updates.
self._sender.receive(update_state=self.updates)
except AttributeError:
# 'NoneType' object has no attribute 'receive'.
# The only moment when this can happen is reconnection
# was triggered from another thread and the ._sender
# was set to None, so close this thread and exit by return.
self._recv_thread = None
return
except TimeoutError:
# No problem.
pass
except ConnectionResetError:
if self._recv_thread is not None:
# Do NOT attempt reconnecting unless the connection was
# finished by the user -> ._recv_thread is None
# NOTE(review): the condition above reconnects when
# ._recv_thread is NOT None, which reads as the opposite of
# this comment's wording - confirm the intended meaning.
self._logger.debug('Server disconnected us. Reconnecting...')
self._recv_thread = None # Not running anymore
self.reconnect()
return
except Exception as e:
# Unknown exception, pass it to the main thread
# (stored through self.updates.set_error).
self.updates.set_error(e)
self._recv_thread = None
return
# endregion

View File

@ -1,2 +1,5 @@
from .tlobject import TLObject
from .session import Session
from .gzip_packed import GzipPacked
from .tl_message import TLMessage
from .message_container import MessageContainer

View File

@ -0,0 +1,38 @@
import gzip
import struct
from . import TLObject
class GzipPacked(TLObject):
    """TLObject that wraps raw bytes and serializes them gzip-compressed."""
    CONSTRUCTOR_ID = 0x3072cfa1

    def __init__(self, data):
        """Store ``data``, the bytes to be compressed by ``to_bytes``."""
        super().__init__()
        self.data = data

    @staticmethod
    def gzip_if_smaller(request):
        """Serialize ``request``, gzipping the result when that pays off.

        Calls ``request.to_bytes()``. Only content related requests
        longer than 512 bytes are considered for compression, and the
        packed form is returned only if it is strictly smaller than the
        plain one; in every other case the plain bytes are returned.
        """
        plain = request.to_bytes()
        # TODO This threshold could be configurable
        if not request.content_related or len(plain) <= 512:
            return plain

        packed = GzipPacked(plain).to_bytes()
        if len(packed) < len(plain):
            return packed
        return plain

    def to_bytes(self):
        """Return the constructor ID followed by the gzipped payload,
        the latter written through ``TLObject.serialize_bytes``."""
        # TODO Maybe compress level could be an option
        header = struct.pack('<I', GzipPacked.CONSTRUCTOR_ID)
        return header + TLObject.serialize_bytes(gzip.compress(self.data))

    @staticmethod
    def read(reader):
        """Skip the constructor code, then read and decompress the bytes."""
        reader.read_int(signed=False)  # code
        compressed = reader.tgread_bytes()
        return gzip.decompress(compressed)

View File

@ -0,0 +1,27 @@
import struct
from . import TLObject
class MessageContainer(TLObject):
    """TLObject holding several messages that are serialized together."""
    CONSTRUCTOR_ID = 0x73f1f8dc

    def __init__(self, messages):
        """Store ``messages``; the container itself is flagged as not
        content related."""
        super().__init__()
        self.content_related = False
        self.messages = messages

    def to_bytes(self):
        """Return the constructor ID, the message count and then every
        message's ``to_bytes()`` output, concatenated in order."""
        header = struct.pack(
            '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
        )
        body = b''.join(message.to_bytes() for message in self.messages)
        return header + body

    @staticmethod
    def iter_read(reader):
        """Yield ``(msg_id, sequence, length)`` for each inner message.

        The constructor code and the message count are read first; the
        three header fields of each message are then read in that order.
        """
        reader.read_int(signed=False)  # code
        count = reader.read_int()
        for _ in range(count):
            # Tuple items are evaluated left to right: id, seq, length.
            yield reader.read_long(), reader.read_int(), reader.read_int()

View File

@ -118,7 +118,7 @@ class Session:
# FIXME We need to import the AuthKey here or otherwise
# we get cyclic dependencies.
from ..crypto import AuthKey
if data['auth_key_data'] is not None:
if data.get('auth_key_data', None) is not None:
key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key)

17
telethon/tl/tl_message.py Normal file
View File

@ -0,0 +1,17 @@
import struct
from . import TLObject, GzipPacked
class TLMessage(TLObject):
    """A request paired with the message ID and sequence number it is
    sent with.

    See https://core.telegram.org/mtproto/service_messages#simple-container
    """
    def __init__(self, session, request):
        """Take a new message ID and sequence number from ``session``
        (in that order) for ``request``."""
        super().__init__()
        # This object drops the content_related attribute it inherited.
        del self.content_related
        self.msg_id = session.get_new_msg_id()
        self.seq_no = session.generate_sequence(request.content_related)
        self.request = request

    def to_bytes(self):
        """Return msg_id, seq_no and body length followed by the body,
        where the body is the (optionally gzipped) serialized request."""
        payload = GzipPacked.gzip_if_smaller(self.request)
        header = struct.pack('<qii', self.msg_id, self.seq_no, len(payload))
        return header + payload

View File

@ -9,7 +9,6 @@ class TLObject:
self.rpc_error = None
# These should be overrode
self.constructor_id = 0
self.content_related = False # Only requests/functions/queries are
# These should not be overrode
@ -20,10 +19,13 @@ class TLObject:
"""
if indent is None:
if isinstance(obj, TLObject):
return '{{{}: {}}}'.format(
type(obj).__name__,
TLObject.pretty_format(obj.to_dict())
)
children = obj.to_dict(recursive=False)
if children:
return '{}: {}'.format(
type(obj).__name__, TLObject.pretty_format(children)
)
else:
return type(obj).__name__
if isinstance(obj, dict):
return '{{{}}}'.format(', '.join(
'{}: {}'.format(
@ -41,12 +43,13 @@ class TLObject:
else:
result = []
if isinstance(obj, TLObject):
result.append('{')
result.append(type(obj).__name__)
result.append(': ')
result.append(TLObject.pretty_format(
obj.to_dict(), indent
))
children = obj.to_dict(recursive=False)
if children:
result.append(': ')
result.append(TLObject.pretty_format(
obj.to_dict(recursive=False), indent
))
elif isinstance(obj, dict):
result.append('{\n')
@ -80,12 +83,43 @@ class TLObject:
return ''.join(result)
@staticmethod
def serialize_bytes(data):
    """Write bytes by using Telegram guidelines.

    ``data`` may be bytes or a str (encoded as UTF-8 first). Payloads
    shorter than 254 bytes get a 1-byte length prefix; longer ones get
    the marker byte 254 followed by the low 24 bits of the length in
    little-endian order. Zero bytes are appended so that the prefix
    plus the payload is a multiple of 4 bytes long.
    """
    if isinstance(data, str):
        data = data.encode('utf-8')

    size = len(data)
    if size < 254:
        prefix = bytes([size])
    else:
        prefix = b'\xfe' + (size & 0xFFFFFF).to_bytes(3, 'little')

    # Zero bytes needed to round prefix + payload up to a multiple of 4.
    padding = -(len(prefix) + size) % 4
    return b''.join((prefix, data, bytes(padding)))
# These should be overrode
def to_dict(self):
def to_dict(self, recursive=True):
return {}
def on_send(self, writer):
pass
def to_bytes(self):
return b''
def on_response(self, reader):
pass

View File

@ -1,6 +1,7 @@
import logging
from collections import deque
from datetime import datetime
from threading import RLock, Event
from threading import RLock, Event, Thread
from .tl import types as tl
@ -9,78 +10,136 @@ class UpdateState:
"""Used to hold the current state of processed updates.
To retrieve an update, .poll() should be called.
"""
def __init__(self, polling):
self._polling = polling
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke all the self.handlers.
"""
self._workers = workers
self._worker_threads = []
self.handlers = []
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
self._logger = logging.getLogger(__name__)
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
self._setup_workers()
def can_poll(self):
"""Returns True if a call to .poll() won't lock"""
return self._updates_available.is_set()
def poll(self):
"""Polls an update or blocks until an update object is available"""
if not self._polling:
raise ValueError('Updates are not being polled hence not saved.')
def poll(self, timeout=None):
"""Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
if not self._updates_available.wait(timeout=timeout):
return
self._updates_available.wait()
with self._updates_lock:
if not self._updates_available.is_set():
return
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
if isinstance(update, Exception):
raise update # Some error was set through .set_error()
raise update # Some error was set through (surely StopIteration)
return update
def get_polling(self):
return self._polling
def get_workers(self):
return self._workers
def set_polling(self, polling):
self._polling = polling
if not polling:
with self._updates_lock:
self._updates.clear()
polling = property(fget=get_polling, fset=set_polling)
def set_error(self, error):
"""Sets an error, so that the next call to .poll() will raise it.
Can be (and is) used to pass exceptions between threads.
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
with self._updates_lock:
# Insert at the beginning so the very next poll causes an error
# TODO Should this reset the pts and such?
self._updates.insert(0, error)
self._updates_available.set()
self._stop_workers()
self._workers = n
if n is None:
self._updates.clear()
else:
self._setup_workers()
def check_error(self):
with self._updates_lock:
if self._updates and isinstance(self._updates[0], Exception):
raise self._updates.pop()
workers = property(fget=get_workers, fset=set_workers)
def _stop_workers(self):
"""Raises "StopIterationException" on the worker threads to stop them,
and also clears all of them off the list
"""
if self._workers:
with self._updates_lock:
# Insert at the beginning so the very next poll causes an error
# on all the worker threads
# TODO Should this reset the pts and such?
for _ in range(self._workers):
self._updates.appendleft(StopIteration())
self._updates_available.set()
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def _setup_workers(self):
    """Spawns and starts 'self._workers' daemon threads running the
    worker loop, unless some already exist or no workers are wanted.
    """
    wants_workers = bool(self._workers)
    if self._worker_threads or not wants_workers:
        # There already are workers, or workers is None or 0. Do nothing.
        return

    for index in range(self._workers):
        worker = Thread(
            target=UpdateState._worker_loop,
            name='UpdateWorker{}'.format(index),
            daemon=True,
            args=(self, index)
        )
        # Registered before starting so it is tracked as soon as it runs
        self._worker_threads.append(worker)
        worker.start()
def _worker_loop(self, wid):
while True:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
# TODO Maybe people can add different handlers per update type
if update:
for handler in self.handlers:
handler(update)
except StopIteration:
break
except Exception as e:
# We don't want to crash a worker thread due to any reason
self._logger.debug(
'[ERROR] Unhandled exception on worker {}'.format(wid), e
)
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if not self._polling and not self.handlers:
return
if self._workers is None:
return # No processing needs to be done if nobody's working
with self._updates_lock:
if isinstance(update, tl.updates.State):
self._state = update
elif not hasattr(update, 'pts') or update.pts > self._state.pts:
self._state.pts = getattr(update, 'pts', self._state.pts)
return # Nothing else to be done
if self._polling:
self._updates.append(update)
self._updates_available.set()
pts = getattr(update, 'pts', self._state.pts)
if hasattr(update, 'pts') and pts <= self._state.pts:
return # We already handled this update
for handler in self.handlers:
handler(update)
self._state.pts = pts
self._updates.append(update)
self._updates_available.set()

View File

@ -10,8 +10,17 @@ from .tl.types import (
ChatPhoto, InputPeerChannel, InputPeerChat, InputPeerUser, InputPeerEmpty,
MessageMediaDocument, MessageMediaPhoto, PeerChannel, InputChannel,
UserEmpty, InputUser, InputUserEmpty, InputUserSelf, InputPeerSelf,
PeerChat, PeerUser, User, UserFull, UserProfilePhoto, Document
)
PeerChat, PeerUser, User, UserFull, UserProfilePhoto, Document,
MessageMediaContact, MessageMediaEmpty, MessageMediaGame, MessageMediaGeo,
MessageMediaUnsupported, MessageMediaVenue, InputMediaContact,
InputMediaDocument, InputMediaEmpty, InputMediaGame,
InputMediaGeoPoint, InputMediaPhoto, InputMediaVenue, InputDocument,
DocumentEmpty, InputDocumentEmpty, Message, GeoPoint, InputGeoPoint,
GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty,
InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty,
FileLocationUnavailable, InputMediaUploadedDocument,
InputMediaUploadedPhoto,
DocumentAttributeFilename)
def get_display_name(entity):
@ -65,13 +74,10 @@ def _raise_cast_fail(entity, target):
def get_input_peer(entity):
"""Gets the input peer for the given "entity" (user, chat or channel).
A ValueError is raised if the given entity isn't a supported type."""
if entity is None:
return None
if not isinstance(entity, TLObject):
_raise_cast_fail(entity, 'InputPeer')
if type(entity).subclass_of_id == 0xc91c90b6: # crc32(b'InputPeer')
if type(entity).SUBCLASS_OF_ID == 0xc91c90b6: # crc32(b'InputPeer')
return entity
if isinstance(entity, User):
@ -109,13 +115,10 @@ def get_input_peer(entity):
def get_input_channel(entity):
"""Similar to get_input_peer, but for InputChannel's alone"""
if entity is None:
return None
if not isinstance(entity, TLObject):
_raise_cast_fail(entity, 'InputChannel')
if type(entity).subclass_of_id == 0x40f202fd: # crc32(b'InputChannel')
if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel')
return entity
if isinstance(entity, Channel) or isinstance(entity, ChannelForbidden):
@ -129,13 +132,10 @@ def get_input_channel(entity):
def get_input_user(entity):
"""Similar to get_input_peer, but for InputUser's alone"""
if entity is None:
return None
if not isinstance(entity, TLObject):
_raise_cast_fail(entity, 'InputUser')
if type(entity).subclass_of_id == 0xe669bf46: # crc32(b'InputUser')
if type(entity).SUBCLASS_OF_ID == 0xe669bf46: # crc32(b'InputUser')
return entity
if isinstance(entity, User):
@ -156,27 +156,169 @@ def get_input_user(entity):
_raise_cast_fail(entity, 'InputUser')
def get_input_document(document):
    """Converts 'document' into its InputDocument counterpart, in the
    same spirit as get_input_peer. Media and messages holding a
    document are unwrapped; unsupported values go to _raise_cast_fail.
    """
    if not isinstance(document, TLObject):
        _raise_cast_fail(document, 'InputDocument')

    kind = type(document)
    if kind.SUBCLASS_OF_ID == 0xf33fdb68:  # crc32(b'InputDocument')
        return document

    if isinstance(document, Document):
        return InputDocument(
            id=document.id, access_hash=document.access_hash
        )
    elif isinstance(document, DocumentEmpty):
        return InputDocumentEmpty()
    elif isinstance(document, MessageMediaDocument):
        # Unwrap the media and convert the document it carries
        return get_input_document(document.document)
    elif isinstance(document, Message):
        # Unwrap the message and convert whatever media it carries
        return get_input_document(document.media)

    _raise_cast_fail(document, 'InputDocument')
def get_input_photo(photo):
    """Similar to get_input_peer, but for photos.

    Accepts anything that already is an InputPhoto, a Photo or
    PhotoEmpty, and (like get_input_document does for documents) a
    MessageMediaPhoto or a Message, which are unwrapped first. Any
    other value is handed to _raise_cast_fail.
    """
    if not isinstance(photo, TLObject):
        _raise_cast_fail(photo, 'InputPhoto')

    if type(photo).SUBCLASS_OF_ID == 0x846363e0:  # crc32(b'InputPhoto')
        return photo

    if isinstance(photo, Photo):
        return InputPhoto(id=photo.id, access_hash=photo.access_hash)

    if isinstance(photo, PhotoEmpty):
        return InputPhotoEmpty()

    if isinstance(photo, MessageMediaPhoto):
        # Unwrap the media and convert the photo it carries
        return get_input_photo(photo.photo)

    if isinstance(photo, Message):
        # Unwrap the message and convert whatever media it carries
        return get_input_photo(photo.media)

    _raise_cast_fail(photo, 'InputPhoto')
def get_input_geo(geo):
    """Converts 'geo' into its InputGeoPoint counterpart, in the same
    spirit as get_input_peer. Media and messages holding a geo point
    are unwrapped; unsupported values go to _raise_cast_fail.
    """
    if not isinstance(geo, TLObject):
        _raise_cast_fail(geo, 'InputGeoPoint')

    kind = type(geo)
    if kind.SUBCLASS_OF_ID == 0x430d225:  # crc32(b'InputGeoPoint')
        return geo

    if isinstance(geo, GeoPoint):
        return InputGeoPoint(lat=geo.lat, long=geo.long)
    elif isinstance(geo, GeoPointEmpty):
        return InputGeoPointEmpty()
    elif isinstance(geo, MessageMediaGeo):
        # Unwrap the media and convert the point it carries
        return get_input_geo(geo.geo)
    elif isinstance(geo, Message):
        # Unwrap the message and convert whatever media it carries
        return get_input_geo(geo.media)

    _raise_cast_fail(geo, 'InputGeoPoint')
def get_input_media(media, user_caption=None, is_photo=False):
    """Similar to get_input_peer, but for media.

    If the media is a file location and is_photo is known to be True,
    it will be treated as an InputMediaUploadedPhoto.

    :param media: the object to convert. Values that already are an
                  InputMedia are returned untouched.
    :param user_caption: caption to use instead of the one stored in
                         the media; None keeps the media's own caption.
    :param is_photo: whether a FileLocation should be sent as a photo
                     rather than as a generic document.
    """
    if not isinstance(media, TLObject):
        _raise_cast_fail(media, 'InputMedia')

    if type(media).SUBCLASS_OF_ID == 0xfaf846f4:  # crc32(b'InputMedia')
        return media

    if isinstance(media, MessageMediaPhoto):
        return InputMediaPhoto(
            id=get_input_photo(media.photo),
            caption=media.caption if user_caption is None else user_caption,
            ttl_seconds=media.ttl_seconds
        )

    if isinstance(media, MessageMediaDocument):
        return InputMediaDocument(
            id=get_input_document(media.document),
            caption=media.caption if user_caption is None else user_caption,
            ttl_seconds=media.ttl_seconds
        )

    if isinstance(media, FileLocation):
        if is_photo:
            return InputMediaUploadedPhoto(
                file=media,
                caption=user_caption or ''
            )
        else:
            return InputMediaUploadedDocument(
                file=media,
                mime_type='application/octet-stream',  # unknown, assume bytes
                attributes=[DocumentAttributeFilename('unnamed')],
                caption=user_caption or ''
            )

    if isinstance(media, MessageMediaGame):
        return InputMediaGame(id=media.game.id)

    if isinstance(media, (ChatPhoto, UserProfilePhoto)):
        # Prefer the big version, falling back to the small one. The
        # caller's caption is forwarded so it is not silently dropped.
        if isinstance(media.photo_big, FileLocationUnavailable):
            location = media.photo_small
        else:
            location = media.photo_big
        return get_input_media(
            location, user_caption=user_caption, is_photo=True
        )

    if isinstance(media, MessageMediaContact):
        return InputMediaContact(
            phone_number=media.phone_number,
            first_name=media.first_name,
            last_name=media.last_name
        )

    if isinstance(media, MessageMediaGeo):
        return InputMediaGeoPoint(geo_point=get_input_geo(media.geo))

    if isinstance(media, MessageMediaVenue):
        return InputMediaVenue(
            geo_point=get_input_geo(media.geo),
            title=media.title,
            address=media.address,
            provider=media.provider,
            venue_id=media.venue_id
        )

    if isinstance(media, (
            MessageMediaEmpty, MessageMediaUnsupported,
            FileLocationUnavailable, ChatPhotoEmpty,
            UserProfilePhotoEmpty)):
        return InputMediaEmpty()

    if isinstance(media, Message):
        # Forward both options so that unwrapping a message behaves
        # exactly like passing its media directly
        return get_input_media(
            media.media, user_caption=user_caption, is_photo=is_photo
        )

    _raise_cast_fail(media, 'InputMedia')
def find_user_or_chat(peer, users, chats):
"""Finds the corresponding user or chat given a peer.
Returns None if it was not found"""
try:
if isinstance(peer, PeerUser):
return next(u for u in users if u.id == peer.user_id)
elif isinstance(peer, PeerChat):
return next(c for c in chats if c.id == peer.chat_id)
if isinstance(peer, PeerUser):
peer, where = peer.user_id, users
else:
where = chats
if isinstance(peer, PeerChat):
peer = peer.chat_id
elif isinstance(peer, PeerChannel):
return next(c for c in chats if c.id == peer.channel_id)
except StopIteration: return
peer = peer.channel_id
if isinstance(peer, int):
try: return next(u for u in users if u.id == peer)
except StopIteration: pass
try: return next(c for c in chats if c.id == peer)
except StopIteration: pass
if isinstance(where, dict):
return where.get(peer)
else:
try:
return next(x for x in where if x.id == peer)
except StopIteration:
pass
def get_appropriated_part_size(file_size):

View File

@ -96,6 +96,17 @@ class TLObject:
result=match.group(3),
is_function=is_function)
def class_name(self):
    """Builds the Python-style class name for this object out of its
    TL name, appending "Request" when it describes a function.
    """
    # Courtesy of http://stackoverflow.com/a/31531797/4759433
    camel = re.sub(
        r'_([a-z])', lambda found: found.group(1).upper(), self.name
    )
    head, tail = camel[:1], camel[1:]
    # Underscores the regex left behind are dropped from the tail only
    name = head.upper() + tail.replace('_', '')
    # If it's a function, let it end with "Request" to identify them
    return name + 'Request' if self.is_function else name
def sorted_args(self):
"""Returns the arguments properly sorted and ready to plug-in
into a Python's method header (i.e., flags and those which
@ -197,8 +208,8 @@ class TLArg:
else:
self.flag_indicator = False
self.is_generic = arg_type.startswith('!')
self.type = arg_type.lstrip(
'!') # Strip the exclamation mark always to have only the name
# Strip the exclamation mark always to have only the name
self.type = arg_type.lstrip('!')
# The type may be a flag (flags.IDX?REAL_TYPE)
# Note that 'flags' is NOT the flags name; this is determined by a previous argument
@ -233,6 +244,24 @@ class TLArg:
self.generic_definition = generic_definition
def type_hint(self):
    """Returns the Python type hint describing this argument, as used
    in the generated ':param' documentation lines.

    Base TL types map to their Python equivalent and anything else is
    hinted as 'TLObject'. Vectors are wrapped in 'list[...]' and flag
    arguments get ' | None' appended (a 'date' hint already has it).
    """
    result = {
        'int': 'int',
        'long': 'int',
        'int128': 'int',
        'int256': 'int',
        'double': 'float',
        'string': 'str',
        'date': 'datetime.datetime | None',  # None date = 0 timestamp
        'bytes': 'bytes',
        'Bool': 'bool',
        'true': 'bool',
    }.get(self.type, 'TLObject')

    if self.is_vector:
        result = 'list[{}]'.format(result)

    if self.is_flag and self.type != 'date':
        result += ' | None'

    return result
def __str__(self):
# Find the real type representation by updating it as required
real_type = self.type

View File

@ -32,16 +32,16 @@
/// Authorization key creation
///////////////////////////////
resPQ#05162463 nonce:int128 server_nonce:int128 pq:string server_public_key_fingerprints:Vector<long> = ResPQ;
resPQ#05162463 nonce:int128 server_nonce:int128 pq:bytes server_public_key_fingerprints:Vector<long> = ResPQ;
p_q_inner_data#83c95aec pq:string p:string q:string nonce:int128 server_nonce:int128 new_nonce:int256 = P_Q_inner_data;
p_q_inner_data#83c95aec pq:bytes p:bytes q:bytes nonce:int128 server_nonce:int128 new_nonce:int256 = P_Q_inner_data;
server_DH_params_fail#79cb045d nonce:int128 server_nonce:int128 new_nonce_hash:int128 = Server_DH_Params;
server_DH_params_ok#d0e8075c nonce:int128 server_nonce:int128 encrypted_answer:string = Server_DH_Params;
server_DH_params_ok#d0e8075c nonce:int128 server_nonce:int128 encrypted_answer:bytes = Server_DH_Params;
server_DH_inner_data#b5890dba nonce:int128 server_nonce:int128 g:int dh_prime:string g_a:string server_time:int = Server_DH_inner_data;
server_DH_inner_data#b5890dba nonce:int128 server_nonce:int128 g:int dh_prime:bytes g_a:bytes server_time:int = Server_DH_inner_data;
client_DH_inner_data#6643b654 nonce:int128 server_nonce:int128 retry_id:long g_b:string = Client_DH_Inner_Data;
client_DH_inner_data#6643b654 nonce:int128 server_nonce:int128 retry_id:long g_b:bytes = Client_DH_Inner_Data;
dh_gen_ok#3bcbf734 nonce:int128 server_nonce:int128 new_nonce_hash1:int128 = Set_client_DH_params_answer;
dh_gen_retry#46dc1fb9 nonce:int128 server_nonce:int128 new_nonce_hash2:int128 = Set_client_DH_params_answer;
@ -55,9 +55,9 @@ destroy_auth_key_fail#ea109b13 = DestroyAuthKeyRes;
req_pq#60469778 nonce:int128 = ResPQ;
req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:string q:string public_key_fingerprint:long encrypted_data:string = Server_DH_Params;
req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:bytes q:bytes public_key_fingerprint:long encrypted_data:bytes = Server_DH_Params;
set_client_DH_params#f5045f1f nonce:int128 server_nonce:int128 encrypted_data:string = Set_client_DH_params_answer;
set_client_DH_params#f5045f1f nonce:int128 server_nonce:int128 encrypted_data:bytes = Set_client_DH_params_answer;
destroy_auth_key#d1435160 = DestroyAuthKeyRes;

View File

@ -1,6 +1,7 @@
import os
import re
import shutil
import struct
from zlib import crc32
from collections import defaultdict
@ -107,8 +108,7 @@ class TLGenerator:
if tlobject.namespace:
builder.write('.' + tlobject.namespace)
builder.writeln('.{},'.format(
TLGenerator.get_class_name(tlobject)))
builder.writeln('.{},'.format(tlobject.class_name()))
builder.current_indent -= 1
builder.writeln('}')
@ -137,13 +137,29 @@ class TLGenerator:
x for x in namespace_tlobjects.keys() if x
)))
# Import 'get_input_*' utils
# TODO Support them on types too
if 'functions' in out_dir:
builder.writeln(
'from {}.utils import get_input_peer, '
'get_input_channel, get_input_user, '
'get_input_media'.format('.' * depth)
)
# Import 'os' for those needing access to 'os.urandom()'
# Currently only 'random_id' needs 'os' to be imported,
# for all those TLObjects with arg.can_be_inferred.
builder.writeln('import os')
# Import struct for the .to_bytes(self) serialization
builder.writeln('import struct')
# Generate the class for every TLObject
for t in sorted(tlobjects, key=lambda x: x.name):
TLGenerator._write_source_code(
t, builder, depth, type_constructors
)
while builder.current_indent != 0:
builder.end_block()
builder.current_indent = 0
@staticmethod
def _write_source_code(tlobject, builder, depth, type_constructors):
@ -154,35 +170,15 @@ class TLGenerator:
the Type: [Constructors] must be given for proper
importing and documentation strings.
"""
if tlobject.is_function:
util_imports = set()
for a in tlobject.args:
# We can automatically convert some "full" types to
# "input only" (like User -> InputPeerUser, etc.)
if a.type == 'InputPeer':
util_imports.add('get_input_peer')
elif a.type == 'InputChannel':
util_imports.add('get_input_channel')
elif a.type == 'InputUser':
util_imports.add('get_input_user')
if util_imports:
builder.writeln('from {}.utils import {}'.format(
'.' * depth, ', '.join(util_imports)))
if any(a for a in tlobject.args if a.can_be_inferred):
# Currently only 'random_id' needs 'os' to be imported
builder.writeln('import os')
builder.writeln()
builder.writeln()
builder.writeln('class {}(TLObject):'.format(
TLGenerator.get_class_name(tlobject)))
builder.writeln('class {}(TLObject):'.format(tlobject.class_name()))
# Class-level variable to store its Telegram's constructor ID
builder.writeln('constructor_id = {}'.format(hex(tlobject.id)))
builder.writeln('subclass_of_id = {}'.format(
hex(crc32(tlobject.result.encode('ascii')))))
builder.writeln('CONSTRUCTOR_ID = {}'.format(hex(tlobject.id)))
builder.writeln('SUBCLASS_OF_ID = {}'.format(
hex(crc32(tlobject.result.encode('ascii'))))
)
builder.writeln()
# Flag arguments must go last
@ -221,17 +217,10 @@ class TLGenerator:
builder.writeln('"""')
for arg in args:
if not arg.flag_indicator:
builder.write(
':param {}: Telegram type: "{}".'
.format(arg.name, arg.type)
)
if arg.is_vector:
builder.write(' Must be a list.'.format(arg.name))
if arg.is_generic:
builder.write(' Must be another TLObject request.')
builder.writeln()
builder.writeln(':param {} {}:'.format(
arg.type_hint(), arg.name
))
builder.current_indent -= 1 # It will auto-indent (':')
# We also want to know what type this request returns
# or to which type this constructor belongs to
@ -246,12 +235,11 @@ class TLGenerator:
builder.writeln('This type has no constructors.')
elif len(constructors) == 1:
builder.writeln('Instance of {}.'.format(
TLGenerator.get_class_name(constructors[0])
constructors[0].class_name()
))
else:
builder.writeln('Instance of either {}.'.format(
', '.join(TLGenerator.get_class_name(c)
for c in constructors)
', '.join(c.class_name() for c in constructors)
))
builder.writeln('"""')
@ -274,56 +262,77 @@ class TLGenerator:
builder.end_block()
# Write the to_dict(self) method
builder.writeln('def to_dict(self, recursive=True):')
if args:
builder.writeln('def to_dict(self):')
builder.writeln('return {')
builder.current_indent += 1
base_types = ('string', 'bytes', 'int', 'long', 'int128',
'int256', 'double', 'Bool', 'true', 'date')
for arg in args:
builder.write("'{}': ".format(arg.name))
if arg.type in base_types:
if arg.is_vector:
builder.write(
'[] if self.{0} is None else self.{0}[:]'
.format(arg.name)
)
else:
builder.write('self.{}'.format(arg.name))
else:
if arg.is_vector:
builder.write(
'[] if self.{0} is None else [None '
'if x is None else x.to_dict() for x in self.{0}]'
.format(arg.name)
)
else:
builder.write(
'None if self.{0} is None else self.{0}.to_dict()'
.format(arg.name)
)
builder.writeln(',')
builder.current_indent -= 1
builder.writeln("}")
else:
builder.writeln('@staticmethod')
builder.writeln('def to_dict():')
builder.writeln('return {}')
builder.write('return {')
builder.current_indent += 1
base_types = ('string', 'bytes', 'int', 'long', 'int128',
'int256', 'double', 'Bool', 'true', 'date')
for arg in args:
builder.write("'{}': ".format(arg.name))
if arg.type in base_types:
if arg.is_vector:
builder.write('[] if self.{0} is None else self.{0}[:]'
.format(arg.name))
else:
builder.write('self.{}'.format(arg.name))
else:
if arg.is_vector:
builder.write(
'([] if self.{0} is None else [None'
' if x is None else x.to_dict() for x in self.{0}]'
') if recursive else self.{0}'.format(arg.name)
)
else:
builder.write(
'(None if self.{0} is None else self.{0}.to_dict())'
' if recursive else self.{0}'.format(arg.name)
)
builder.writeln(',')
builder.current_indent -= 1
builder.writeln("}")
builder.end_block()
# Write the on_send(self, writer) function
builder.writeln('def on_send(self, writer):')
builder.writeln(
'writer.write_int({}.constructor_id, signed=False)'
.format(TLGenerator.get_class_name(tlobject)))
# Write the .to_bytes() function
builder.writeln('def to_bytes(self):')
# Some objects require more than one flag parameter to be set
# at the same time. In this case, add an assertion.
repeated_args = defaultdict(list)
for arg in tlobject.args:
if arg.is_flag:
repeated_args[arg.flag_index].append(arg)
for ra in repeated_args.values():
if len(ra) > 1:
cnd1 = ('self.{} is None'.format(a.name) for a in ra)
cnd2 = ('self.{} is not None'.format(a.name) for a in ra)
builder.writeln(
"assert ({}) or ({}), '{} parameters must all "
"be None or neither be None'".format(
' and '.join(cnd1), ' and '.join(cnd2),
', '.join(a.name for a in ra)
)
)
builder.writeln("return b''.join((")
builder.current_indent += 1
# First constructor code, we already know its bytes
builder.writeln('{},'.format(repr(struct.pack('<I', tlobject.id))))
for arg in tlobject.args:
TLGenerator.write_onsend_code(builder, arg,
tlobject.args)
if TLGenerator.write_to_bytes(builder, arg, tlobject.args):
builder.writeln(',')
builder.current_indent -= 1
builder.writeln('))')
builder.end_block()
# Write the empty() function, which returns an "empty"
@ -331,8 +340,8 @@ class TLGenerator:
builder.writeln('@staticmethod')
builder.writeln('def empty():')
builder.writeln('return {}({})'.format(
TLGenerator.get_class_name(tlobject), ', '.join(
'None' for _ in range(len(args)))))
tlobject.class_name(), ', '.join('None' for _ in range(len(args)))
))
builder.end_block()
# Write the on_response(self, reader) function
@ -345,18 +354,15 @@ class TLGenerator:
if tlobject.args:
for arg in tlobject.args:
TLGenerator.write_onresponse_code(
builder, arg, tlobject.args)
builder, arg, tlobject.args
)
else:
# If there were no arguments, we still need an
# on_response method, and hence "pass" if empty
builder.writeln('pass')
builder.end_block()
# Write the __repr__(self) and __str__(self) functions
builder.writeln('def __repr__(self):')
builder.writeln("return '{}'".format(repr(tlobject)))
builder.end_block()
# Write the __str__(self) and stringify(self) functions
builder.writeln('def __str__(self):')
builder.writeln('return TLObject.pretty_format(self)')
builder.end_block()
@ -398,6 +404,8 @@ class TLGenerator:
TLGenerator.write_get_input(builder, arg, 'get_input_channel')
elif arg.type == 'InputUser' and tlobject.is_function:
TLGenerator.write_get_input(builder, arg, 'get_input_user')
elif arg.type == 'InputMedia' and tlobject.is_function:
TLGenerator.write_get_input(builder, arg, 'get_input_media')
else:
builder.writeln('self.{0} = {0}'.format(arg.name))
@ -408,29 +416,14 @@ class TLGenerator:
a parameter upon creating the request. Returns False otherwise
"""
if arg.is_vector:
builder.writeln(
'self.{0} = [{1}(_x) for _x in {0}]'
.format(arg.name, get_input_code)
)
pass
builder.write('self.{0} = [{1}(_x) for _x in {0}]'
.format(arg.name, get_input_code))
else:
builder.writeln(
'self.{0} = {1}({0})'.format(arg.name, get_input_code)
)
@staticmethod
def get_class_name(tlobject):
"""Gets the class name following the Python style guidelines"""
# Courtesy of http://stackoverflow.com/a/31531797/4759433
result = re.sub(r'_([a-z])', lambda m: m.group(1).upper(),
tlobject.name)
result = result[:1].upper() + result[1:].replace(
'_', '') # Replace again to fully ensure!
# If it's a function, let it end with "Request" to identify them
if tlobject.is_function:
result += 'Request'
return result
builder.write('self.{0} = {1}({0})'
.format(arg.name, get_input_code))
builder.writeln(
' if {} else None'.format(arg.name) if arg.is_flag else ''
)
@staticmethod
def get_file_name(tlobject, add_extension=False):
@ -445,18 +438,17 @@ class TLGenerator:
return result
@staticmethod
def write_onsend_code(builder, arg, args, name=None):
def write_to_bytes(builder, arg, args, name=None):
"""
Writes the write code for the given argument
Writes the .to_bytes() code for the given argument
:param builder: The source code builder
:param arg: The argument to write
:param args: All the other arguments in TLObject same on_send.
:param args: All the other arguments in TLObject same to_bytes.
This is required to determine the flags value
:param name: The name of the argument. Defaults to "self.argname"
This argument is an option because it's required when
writing Vectors<>
"""
if arg.generic_definition:
return # Do nothing, this only specifies a later type
@ -470,73 +462,91 @@ class TLGenerator:
if arg.is_flag:
if arg.type == 'true':
return # Exit, since True type is never written
elif arg.is_vector:
# Vector flags are special since they consist of 3 values,
# so we need an extra join here. Note that empty vector flags
# should NOT be sent either!
builder.write("b'' if not {} else b''.join((".format(name))
else:
builder.writeln('if {}:'.format(name))
builder.write("b'' if not {} else (".format(name))
if arg.is_vector:
if arg.use_vector_id:
builder.writeln('writer.write_int(0x1cb5c415, signed=False)')
# vector code, unsigned 0x1cb5c415 as little endian
builder.write(r"b'\x15\xc4\xb5\x1c',")
builder.write("struct.pack('<i', len({})),".format(name))
# Cannot unpack the values for the outer tuple through *[(
# since that's a Python >3.5 feature, so add another join.
builder.write("b''.join(")
builder.writeln('writer.write_int(len({}))'.format(name))
builder.writeln('for _x in {}:'.format(name))
# Temporary disable .is_vector, not to enter this if again
arg.is_vector = False
TLGenerator.write_onsend_code(builder, arg, args, name='_x')
# Also disable .is_flag since it's not needed per element
old_flag = arg.is_flag
arg.is_vector = arg.is_flag = False
TLGenerator.write_to_bytes(builder, arg, args, name='x')
arg.is_vector = True
arg.is_flag = old_flag
builder.write(' for x in {})'.format(name))
elif arg.flag_indicator:
# Calculate the flags with those items which are not None
builder.writeln('flags = 0')
for flag in args:
if flag.is_flag:
builder.writeln('flags |= (1 << {}) if {} else 0'.format(
flag.flag_index, 'self.{}'.format(flag.name)))
builder.writeln('writer.write_int(flags)')
builder.writeln()
builder.write("struct.pack('<I', {})".format(
' | '.join('({} if {} else 0)'.format(
1 << flag.flag_index, 'self.{}'.format(flag.name)
) for flag in args if flag.is_flag)
))
elif 'int' == arg.type:
builder.writeln('writer.write_int({})'.format(name))
# struct.pack is around 4 times faster than int.to_bytes
builder.write("struct.pack('<i', {})".format(name))
elif 'long' == arg.type:
builder.writeln('writer.write_long({})'.format(name))
builder.write("struct.pack('<q', {})".format(name))
elif 'int128' == arg.type:
builder.writeln('writer.write_large_int({}, bits=128)'.format(
name))
builder.write("{}.to_bytes(16, 'little', signed=True)".format(name))
elif 'int256' == arg.type:
builder.writeln('writer.write_large_int({}, bits=256)'.format(
name))
builder.write("{}.to_bytes(32, 'little', signed=True)".format(name))
elif 'double' == arg.type:
builder.writeln('writer.write_double({})'.format(name))
builder.write("struct.pack('<d', {})".format(name))
elif 'string' == arg.type:
builder.writeln('writer.tgwrite_string({})'.format(name))
builder.write('TLObject.serialize_bytes({})'.format(name))
elif 'Bool' == arg.type:
builder.writeln('writer.tgwrite_bool({})'.format(name))
# 0x997275b5 if boolean else 0xbc799737
builder.write(
r"b'\xb5ur\x99' if {} else b'7\x97y\xbc'".format(name)
)
elif 'true' == arg.type:
pass # These are actually NOT written! Only used for flags
elif 'bytes' == arg.type:
builder.writeln('writer.tgwrite_bytes({})'.format(name))
builder.write('TLObject.serialize_bytes({})'.format(name))
elif 'date' == arg.type: # Custom format
builder.writeln('writer.tgwrite_date({})'.format(name))
# 0 if datetime is None else int(datetime.timestamp())
builder.write(
r"b'\0\0\0\0' if {0} is None else "
r"struct.pack('<I', int({0}.timestamp()))".format(name)
)
else:
# Else it may be a custom type
builder.writeln('{}.on_send(writer)'.format(name))
# End vector and flag blocks if required (if we opened them before)
if arg.is_vector:
builder.end_block()
builder.write('{}.to_bytes()'.format(name))
if arg.is_flag:
builder.end_block()
builder.write(')')
if arg.is_vector:
builder.write(')') # We were using a tuple
return True # Something was written
@staticmethod
def write_onresponse_code(builder, arg, args, name=None):
@ -562,8 +572,8 @@ class TLGenerator:
was_flag = False
if arg.is_flag:
was_flag = True
builder.writeln('if (flags & (1 << {})) != 0:'.format(
arg.flag_index
builder.writeln('if flags & {}:'.format(
1 << arg.flag_index
))
# Temporary disable .is_flag not to enter this if
# again when calling the method recursively

View File

@ -1,90 +1,61 @@
import os
import unittest
from telethon.extensions import BinaryReader, BinaryWriter
from telethon.tl import TLObject
from telethon.extensions import BinaryReader
class UtilsTests(unittest.TestCase):
@staticmethod
def test_binary_writer_reader():
# Test that we can write and read properly
with BinaryWriter() as writer:
writer.write_byte(1)
writer.write_int(5)
writer.write_long(13)
writer.write_float(17.0)
writer.write_double(25.0)
writer.write(bytes([26, 27, 28, 29, 30, 31, 32]))
writer.write_large_int(2**127, 128, signed=False)
data = writer.get_bytes()
expected = b'\x01\x05\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\x88A\x00\x00\x00\x00\x00\x00' \
b'9@\x1a\x1b\x1c\x1d\x1e\x1f \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80'
assert data == expected, 'Retrieved data does not match the expected value'
# Test that we can read properly
data = b'\x01\x05\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
b'\x88A\x00\x00\x00\x00\x00\x009@\x1a\x1b\x1c\x1d\x1e\x1f ' \
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
b'\x00\x80'
with BinaryReader(data) as reader:
value = reader.read_byte()
assert value == 1, 'Example byte should be 1 but is {}'.format(
value)
assert value == 1, 'Example byte should be 1 but is {}'.format(value)
value = reader.read_int()
assert value == 5, 'Example integer should be 5 but is {}'.format(
value)
assert value == 5, 'Example integer should be 5 but is {}'.format(value)
value = reader.read_long()
assert value == 13, 'Example long integer should be 13 but is {}'.format(
value)
assert value == 13, 'Example long integer should be 13 but is {}'.format(value)
value = reader.read_float()
assert value == 17.0, 'Example float should be 17.0 but is {}'.format(
value)
assert value == 17.0, 'Example float should be 17.0 but is {}'.format(value)
value = reader.read_double()
assert value == 25.0, 'Example double should be 25.0 but is {}'.format(
value)
assert value == 25.0, 'Example double should be 25.0 but is {}'.format(value)
value = reader.read(7)
assert value == bytes([26, 27, 28, 29, 30, 31, 32]), 'Example bytes should be {} but is {}' \
.format(bytes([26, 27, 28, 29, 30, 31, 32]), value)
value = reader.read_large_int(128, signed=False)
assert value == 2**127, 'Example large integer should be {} but is {}'.format(
2**127, value)
# Test Telegram that types are written right
with BinaryWriter() as writer:
writer.write_int(0x60469778)
buffer = writer.get_bytes()
valid = b'\x78\x97\x46\x60' # Tested written bytes using C#'s MemoryStream
assert buffer == valid, 'Written type should be {} but is {}'.format(
list(valid), list(buffer))
assert value == 2**127, 'Example large integer should be {} but is {}'.format(2**127, value)
@staticmethod
def test_binary_tgwriter_tgreader():
small_data = os.urandom(33)
small_data_padded = os.urandom(
19) # +1 byte for length = 20 (evenly divisible by 4)
small_data_padded = os.urandom(19) # +1 byte for length = 20 (%4 = 0)
large_data = os.urandom(999)
large_data_padded = os.urandom(1024)
data = (small_data, small_data_padded, large_data, large_data_padded)
string = 'Testing Telegram strings, this should work properly!'
serialized = b''.join(TLObject.serialize_bytes(d) for d in data) + \
TLObject.serialize_bytes(string)
with BinaryWriter() as writer:
# First write the data
with BinaryReader(serialized) as reader:
# And then try reading it without errors (it should be unharmed!)
for datum in data:
writer.tgwrite_bytes(datum)
writer.tgwrite_string(string)
value = reader.tgread_bytes()
assert value == datum, 'Example bytes should be {} but is {}'.format(
datum, value)
with BinaryReader(writer.get_bytes()) as reader:
# And then try reading it without errors (it should be unharmed!)
for datum in data:
value = reader.tgread_bytes()
assert value == datum, 'Example bytes should be {} but is {}'.format(
datum, value)
value = reader.tgread_string()
assert value == string, 'Example string should be {} but is {}'.format(
string, value)
value = reader.tgread_string()
assert value == string, 'Example string should be {} but is {}'.format(
string, value)