Mirror of https://github.com/LonamiWebs/Telethon.git, synced 2025-08-04 20:20:23 +03:00

Commit a6e295da65: Merge branch 'master' of github.com:lonamiwebs/Telethon

Changed file: .github/ISSUE_TEMPLATE.md (vendored), 29 lines
@@ -1,26 +1,9 @@
<!--
Please remember that issues here should be related to the library itself and NOT your code.
0. The library is Python 3.x, not Python 2.x.
1. If you're posting an issue, make sure it's a bug in the library, not in your code.
2. If you're posting a question, make sure you have read and tried enough things first.
3. Show as much information as possible, including your failed attempts, and the full console output (to include the whole traceback with line numbers).
4. Good looking issues are a lot more appealing. If you need help check out https://guides.github.com/features/mastering-markdown/.

Python 2 is NOT supported. Make sure you're using the latest version of Telethon before reporting:
pip install telethon --upgrade

Some questions are okay, but make sure you've investigated enough on your own before or you will end up on the Wall of Shame:
https://github.com/LonamiWebs/Telethon/wiki/Wall-of-Shame.
You may also want to watch "How (not) to ask a technical question" over https://youtu.be/53zkBvL4ZB4
-->

### What went wrong

Describe what happened or what the error you have is.

```
// paste the crash log here if any
```

### What I've done

Either a code example of what you were trying to do, or steps to reproduce, or methods you have tried invoking.

```python
# Add your Python code here
```

### More information

If you think other information can be relevant (e.g. operative system or other variables), add it here.
@@ -199,14 +199,27 @@ def generate_index(folder, original_paths):
def get_description(arg):
    """Generates a proper description for the given argument"""
    desc = []
    otherwise = False
    if arg.can_be_inferred:
        desc.append('If left to None, it will be inferred automatically.')
    if arg.is_vector:
        desc.append('A list must be supplied for this argument.')
    if arg.is_generic:
        desc.append('A different Request must be supplied for this argument.')
    if arg.is_flag:
        desc.append('If left unspecified, it will be inferred automatically.')
        otherwise = True
    elif arg.is_flag:
        desc.append('This argument can be omitted.')
        otherwise = True

    if arg.is_vector:
        if arg.is_generic:
            desc.append('A list of other Requests must be supplied.')
        else:
            desc.append('A list must be supplied.')
    elif arg.is_generic:
        desc.append('A different Request must be supplied for this argument.')
    else:
        otherwise = False  # Always reset to false if no other text is added

    if otherwise:
        desc.insert(1, 'Otherwise,')
        desc[-1] = desc[-1][:1].lower() + desc[-1][1:]

    return ' '.join(desc)


@@ -218,6 +231,7 @@ def generate_documentation(scheme_file):
    original_paths = {
        'css': 'css/docs.css',
        'arrow': 'img/arrow.svg',
        '404': '404.html',
        'index_all': 'index.html',
        'index_types': 'types/index.html',
        'index_methods': 'methods/index.html',

@@ -360,7 +374,8 @@ def generate_documentation(scheme_file):
    for tltype, constructors in tltypes.items():
        filename = get_path_for_type(tltype)
        out_dir = os.path.dirname(filename)
        os.makedirs(out_dir, exist_ok=True)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Since we don't have access to the full TLObject, split the type
        if '.' in tltype:

@@ -503,15 +518,26 @@ def generate_documentation(scheme_file):
    methods = sorted(methods, key=lambda m: m.name)
    constructors = sorted(constructors, key=lambda c: c.name)

    def fmt(xs):
        ys = {x: get_class_name(x) for x in xs}  # cache TLObject: display
        zs = {}  # create a dict to hold those which have duplicated keys
        for y in ys.values():
            zs[y] = y in zs
        return ', '.join(
            '"{}.{}"'.format(x.namespace, ys[x])
            if zs[ys[x]] and getattr(x, 'namespace', None)
            else '"{}"'.format(ys[x]) for x in xs
        )

    request_names = fmt(methods)
    type_names = fmt(types)
    constructor_names = fmt(constructors)

    def fmt(xs, formatter):
        return ', '.join('"{}"'.format(formatter(x)) for x in xs)

    request_names = fmt(methods, get_class_name)
    type_names = fmt(types, get_class_name)
    constructor_names = fmt(constructors, get_class_name)

    request_urls = fmt(methods, get_create_path_for)
    type_urls = fmt(types, get_create_path_for)
    type_urls = fmt(types, get_path_for_type)
    constructor_urls = fmt(constructors, get_create_path_for)

    replace_dict = {

@@ -528,13 +554,15 @@ def generate_documentation(scheme_file):
        'constructor_urls': constructor_urls
    }

    with open('../res/core.html') as infile:
        with open(original_paths['index_all'], 'w') as outfile:
            text = infile.read()
            for key, value in replace_dict.items():
                text = text.replace('{' + key + '}', str(value))
    shutil.copy('../res/404.html', original_paths['404'])

            outfile.write(text)
    with open('../res/core.html') as infile,\
            open(original_paths['index_all'], 'w') as outfile:
        text = infile.read()
        for key, value in replace_dict.items():
            text = text.replace('{' + key + '}', str(value))

        outfile.write(text)

    # Everything done
    print('Documentation generated.')

@@ -551,5 +579,8 @@ def copy_resources():
if __name__ == '__main__':
    os.makedirs('generated', exist_ok=True)
    os.chdir('generated')
    generate_documentation('../../telethon_generator/scheme.tl')
    copy_resources()
    try:
        generate_documentation('../../telethon_generator/scheme.tl')
        copy_resources()
    finally:
        os.chdir(os.pardir)
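The rewritten `fmt` helper above caches each TLObject's display name and flags names that appear more than once, so that only ambiguous names get qualified with their namespace. A minimal standalone sketch of that dedup-by-display-name idea; the `SimpleNamespace` objects and `format_names` name are stand-ins for illustration, not the generator's real TLObjects:

```python
from types import SimpleNamespace

def format_names(objects, get_name):
    """Quote each object's display name, prefixing the namespace
    only when the bare name would be ambiguous."""
    names = {id(o): get_name(o) for o in objects}   # cache object -> display name
    seen, duplicated = set(), set()
    for name in names.values():                     # mark names appearing more than once
        if name in seen:
            duplicated.add(name)
        seen.add(name)
    return ', '.join(
        '"{}.{}"'.format(o.namespace, names[id(o)])
        if names[id(o)] in duplicated and getattr(o, 'namespace', None)
        else '"{}"'.format(names[id(o)])
        for o in objects
    )

# Toy usage with made-up objects
objs = [SimpleNamespace(name='GetUsers', namespace='users'),
        SimpleNamespace(name='GetUsers', namespace='contacts'),
        SimpleNamespace(name='SendMessage', namespace=None)]
print(format_names(objs, lambda o: o.name))
# "users.GetUsers", "contacts.GetUsers", "SendMessage"
```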
docs/res/404.html (new file, 44 lines)

@@ -0,0 +1,44 @@
<!DOCTYPE html>
<html><head>
<title>Oopsie! | Telethon</title>

<meta charset="utf-8">
<meta http-equiv="Content-type" content="text/html; charset=utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style type="text/css">
    body {
        background-color: #f0f4f8;
        font-family: "Open Sans", "Helvetica Neue", Helvetica, Arial, sans-serif;
    }
    div {
        width: 560px;
        margin: 5em auto;
        padding: 50px;
        background-color: #fff;
        border-radius: 1em;
    }
    a:link, a:visited {
        color: #38488f;
        text-decoration: none;
    }
    @media (max-width: 700px) {
        body {
            background-color: #fff;
        }
        div {
            width: auto;
            margin: 0 auto;
            border-radius: 0;
            padding: 1em;
        }
    }
</style>
</head>
<body>
<div>
    <h1>You seem a bit lost…</h1>
    <p>You seem to be lost! Don't worry, that's just Telegram's API being
    itself. Shall we go back to the <a href="index.html">Main Page</a>?</p>
</div>
</body>
</html>
setup.py (74 changed lines)

@@ -12,34 +12,53 @@ Extra supported commands are:
"""

# To use a consistent encoding
from subprocess import run
from shutil import rmtree
from codecs import open
from sys import argv
from os import path
import os

# Always prefer setuptools over distutils
from setuptools import find_packages, setup

try:
    from telethon import TelegramClient
except ImportError:
except Exception as e:
    print('Failed to import TelegramClient due to', e)
    TelegramClient = None


if __name__ == '__main__':
    if len(argv) >= 2 and argv[1] == 'gen_tl':
        from telethon_generator.tl_generator import TLGenerator
        generator = TLGenerator('telethon/tl')
        if generator.tlobjects_exist():
            print('Detected previous TLObjects. Cleaning...')
            generator.clean_tlobjects()
class TempWorkDir:
    """Switches the working directory to be the one on which this file lives,
    while within the 'with' block.
    """
    def __init__(self):
        self.original = None

        print('Generating TLObjects...')
        generator.generate_tlobjects(
            'telethon_generator/scheme.tl', import_depth=2
        )
        print('Done.')
    def __enter__(self):
        self.original = os.path.abspath(os.path.curdir)
        os.chdir(os.path.abspath(os.path.dirname(__file__)))
        return self

    def __exit__(self, *args):
        os.chdir(self.original)


def gen_tl():
    from telethon_generator.tl_generator import TLGenerator
    generator = TLGenerator('telethon/tl')
    if generator.tlobjects_exist():
        print('Detected previous TLObjects. Cleaning...')
        generator.clean_tlobjects()

    print('Generating TLObjects...')
    generator.generate_tlobjects(
        'telethon_generator/scheme.tl', import_depth=2
    )
    print('Done.')


def main():
    if len(argv) >= 2 and argv[1] == 'gen_tl':
        gen_tl()

    elif len(argv) >= 2 and argv[1] == 'clean_tl':
        from telethon_generator.tl_generator import TLGenerator

@@ -48,6 +67,11 @@ if __name__ == '__main__':
        print('Done.')

    elif len(argv) >= 2 and argv[1] == 'pypi':
        # Need python3.5 or higher, but Telethon is supposed to support 3.x
        # Place it here since noone should be running ./setup.py pypi anyway
        from subprocess import run
        from shutil import rmtree

        for x in ('build', 'dist', 'Telethon.egg-info'):
            rmtree(x, ignore_errors=True)
        run('python3 setup.py sdist', shell=True)

@@ -58,20 +82,21 @@ if __name__ == '__main__':

    else:
        if not TelegramClient:
            print('Run `python3', argv[0], 'gen_tl` first.')
            quit()

        here = path.abspath(path.dirname(__file__))
            gen_tl()
            from telethon import TelegramClient as TgClient
            version = TgClient.__version__
        else:
            version = TelegramClient.__version__

        # Get the long description from the README file
        with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
        with open('README.rst', encoding='utf-8') as f:
            long_description = f.read()

        setup(
            name='Telethon',

            # Versions should comply with PEP440.
            version=TelegramClient.__version__,
            version=version,
            description="Full-featured Telegram client library for Python 3",
            long_description=long_description,

@@ -108,3 +133,8 @@ if __name__ == '__main__':
            ]),
            install_requires=['pyaes', 'rsa']
        )


if __name__ == '__main__':
    with TempWorkDir():  # Could just use a try/finally but this is + reusable
        main()
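The new `TempWorkDir` context manager pins the working directory to the directory that contains `setup.py` for the duration of the `with` block, and restores the old one even if `main()` raises. A generic sketch of the same pattern; this standalone version takes an explicit target directory as an assumption instead of deriving it from `__file__`:

```python
import os

class TempWorkDir:
    """Temporarily switch the process working directory inside a 'with' block."""
    def __init__(self, target):
        self.target = target
        self.original = None

    def __enter__(self):
        self.original = os.path.abspath(os.curdir)   # remember where we were
        os.chdir(os.path.abspath(self.target))       # jump to the target directory
        return self

    def __exit__(self, *args):
        os.chdir(self.original)                      # always restore, even on error

# Usage: relative paths inside the block resolve against /tmp
with TempWorkDir('/tmp'):
    print(os.getcwd())
```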
@@ -1,7 +1,8 @@
import struct
from hashlib import sha1

from .. import helpers as utils
from ..extensions import BinaryReader, BinaryWriter
from ..extensions import BinaryReader


class AuthKey:

@@ -17,10 +18,6 @@ class AuthKey:
        """Calculates the new nonce hash based on
        the current class fields' values
        """
        with BinaryWriter() as writer:
            writer.write(new_nonce)
            writer.write_byte(number)
            writer.write_long(self.aux_hash, signed=False)

            new_nonce_hash = utils.calc_msg_key(writer.get_bytes())
            return new_nonce_hash
        new_nonce = new_nonce.to_bytes(32, 'little', signed=True)
        data = new_nonce + struct.pack('<BQ', number, self.aux_hash)
        return utils.calc_msg_key(data)
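The rewritten `calc_new_nonce_hash` drops `BinaryWriter` and builds the buffer directly: the 32-byte little-endian new nonce, one byte for `number`, and the unsigned 64-bit aux hash packed with `struct`. A small sketch of just that byte layout, with dummy values and a plain SHA1-based stand-in for `utils.calc_msg_key` (the slice used here is an assumption, not the library's guaranteed behaviour):

```python
import struct
from hashlib import sha1

def new_nonce_hash(new_nonce: int, number: int, aux_hash: int) -> bytes:
    # 32-byte little-endian nonce, then one byte 'number', then unsigned 64-bit aux hash
    data = new_nonce.to_bytes(32, 'little', signed=True)
    data += struct.pack('<BQ', number, aux_hash)
    # Stand-in for utils.calc_msg_key: bytes 4..19 of the SHA1 digest
    return sha1(data).digest()[4:20]

print(new_nonce_hash(new_nonce=123456789, number=1, aux_hash=0xDEADBEEF).hex())
```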
@@ -10,56 +10,35 @@ from ..errors import CdnFileTamperedError
class CdnDecrypter:
    """Used when downloading a file results in a 'FileCdnRedirect' to
    both prepare the redirect, decrypt the file as it downloads, and
    ensure the file hasn't been tampered.
    ensure the file hasn't been tampered. https://core.telegram.org/cdn
    """
    def __init__(self, cdn_client, file_token, cdn_aes, cdn_file_hashes):
        self.client = cdn_client
        self.file_token = file_token
        self.cdn_aes = cdn_aes
        self.cdn_file_hashes = cdn_file_hashes
        self.shaes = [sha256() for _ in range(len(cdn_file_hashes))]

    @staticmethod
    def prepare_decrypter(client, client_cls, cdn_redirect):
    def prepare_decrypter(client, cdn_client, cdn_redirect):
        """Prepares a CDN decrypter, returning (decrypter, file data).
        'client' should be the original TelegramBareClient that
        tried to download the file.

        'client_cls' should be the class of the TelegramBareClient.
        'client' should be an existing client not connected to a CDN.
        'cdn_client' should be an already-connected TelegramBareClient
        with the auth key already created.
        """
        # TODO Avoid the need for 'client_cls=TelegramBareClient'
        # https://core.telegram.org/cdn
        cdn_aes = AESModeCTR(
            key=cdn_redirect.encryption_key,
            # 12 first bytes of the IV..4 bytes of the offset (0, big endian)
            iv=cdn_redirect.encryption_iv[:12] + bytes(4)
        )

        # Create a new client on said CDN
        dc = client._get_dc(cdn_redirect.dc_id, cdn=True)
        session = Session(client.session)
        session.server_address = dc.ip_address
        session.port = dc.port
        cdn_client = client_cls(  # Avoid importing TelegramBareClient
            session, client.api_id, client.api_hash,
            timeout=client._timeout
        )
        # This will make use of the new RSA keys for this specific CDN.
        #
        # We assume that cdn_redirect.cdn_file_hashes are ordered by offset,
        # and that there will be enough of these to retrieve the whole file.
        #
        # This relies on the fact that TelegramBareClient._dc_options is
        # static and it won't be called from this DC (it would fail).
        cdn_client.connect()

        # CDN client is ready, create the resulting CdnDecrypter
        decrypter = CdnDecrypter(
            cdn_client, cdn_redirect.file_token,
            cdn_aes, cdn_redirect.cdn_file_hashes
        )

        cdn_file = client(GetCdnFileRequest(
        cdn_file = cdn_client(GetCdnFileRequest(
            file_token=cdn_redirect.file_token,
            offset=cdn_redirect.cdn_file_hashes[0].offset,
            limit=cdn_redirect.cdn_file_hashes[0].limit
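The decrypter seeds AES-CTR with the first 12 bytes of the redirect's IV plus a 4-byte big-endian counter that starts at 0 (offset 0). A sketch of how that IV could be assembled for an arbitrary starting offset, assuming the usual CTR convention that the counter advances one per 16-byte block; the `encryption_iv` value and helper name are made up for illustration:

```python
import os

AES_BLOCK_SIZE = 16

def ctr_iv_for_offset(encryption_iv: bytes, offset: int) -> bytes:
    """First 12 IV bytes are fixed; the last 4 encode the block counter, big endian."""
    assert offset % AES_BLOCK_SIZE == 0
    counter = offset // AES_BLOCK_SIZE
    return encryption_iv[:12] + counter.to_bytes(4, 'big')

encryption_iv = os.urandom(16)                     # placeholder for cdn_redirect.encryption_iv
print(ctr_iv_for_offset(encryption_iv, 0).hex())   # same as encryption_iv[:12] + bytes(4)
```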
@@ -1,4 +1,5 @@
import os
import struct
from hashlib import sha1
try:
    import rsa

@@ -7,7 +8,7 @@ except ImportError:
    rsa = None
    raise ImportError('Missing module "rsa", please install via pip.')

from ..extensions import BinaryWriter
from ..tl import TLObject


# {fingerprint: Crypto.PublicKey.RSA._RSAobj} dictionary

@@ -34,11 +35,10 @@ def _compute_fingerprint(key):
    """For a given Crypto.RSA key, computes its 8-bytes-long fingerprint
    in the same way that Telegram does.
    """
    with BinaryWriter() as writer:
        writer.tgwrite_bytes(get_byte_array(key.n))
        writer.tgwrite_bytes(get_byte_array(key.e))
        # Telegram uses the last 8 bytes as the fingerprint
        return sha1(writer.get_bytes()).digest()[-8:]
    n = TLObject.serialize_bytes(get_byte_array(key.n))
    e = TLObject.serialize_bytes(get_byte_array(key.e))
    # Telegram uses the last 8 bytes as the fingerprint
    return struct.unpack('<q', sha1(n + e).digest()[-8:])[0]


def add_key(pub):
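The fingerprint is now computed without `BinaryWriter`: serialize `n` and `e` as TL-padded byte strings, SHA1 the concatenation, and read the last 8 digest bytes as a signed little-endian 64-bit integer. A self-contained sketch of that flow, with a toy short-form TL serializer standing in for `TLObject.serialize_bytes` and small example numbers instead of a real 2048-bit key:

```python
import struct
from hashlib import sha1

def serialize_bytes(data: bytes) -> bytes:
    """Toy TL 'bytes' serializer: short form only (len < 254), padded to 4 bytes."""
    assert len(data) < 254
    result = bytes([len(data)]) + data
    return result + bytes(-len(result) % 4)          # pad with zeros to a multiple of 4

def get_byte_array(integer: int) -> bytes:
    return integer.to_bytes((integer.bit_length() + 7) // 8, 'big')

def compute_fingerprint(n: int, e: int) -> int:
    data = serialize_bytes(get_byte_array(n)) + serialize_bytes(get_byte_array(e))
    return struct.unpack('<q', sha1(data).digest()[-8:])[0]

print(compute_fingerprint(n=0xC150023E2F70DB79, e=65537))   # toy values, not a real key
```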
@@ -1,5 +1,6 @@
import urllib.request
import re
from threading import Thread

from .common import (
    ReadCancelledError, InvalidParameterError, TypeNotFoundError,

@@ -18,21 +19,29 @@ from .rpc_errors_401 import *
from .rpc_errors_420 import *


def report_error(code, message, report_method):
    try:
        # Ensure it's signed
        report_method = int.from_bytes(
            report_method.to_bytes(4, 'big'), 'big', signed=True
        )
        url = urllib.request.urlopen(
            'https://rpc.pwrtelegram.xyz?code={}&error={}&method={}'
            .format(code, message, report_method),
            timeout=5
        )
        url.read()
        url.close()
    except:
        "We really don't want to crash when just reporting an error"


def rpc_message_to_error(code, message, report_method=None):
    if report_method is not None:
        try:
            # Ensure it's signed
            report_method = int.from_bytes(
                report_method.to_bytes(4, 'big'), 'big', signed=True
            )
            url = urllib.request.urlopen(
                'https://rpc.pwrtelegram.xyz?code={}&error={}&method={}'
                .format(code, message, report_method)
            )
            url.read()
            url.close()
        except:
            "We really don't want to crash when just reporting an error"
        Thread(
            target=report_error,
            args=(code, message, report_method)
        ).start()

    errors = {
        303: rpc_errors_303_all,
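The refactor moves the PWRTelegram reporting into `report_error` and fires it from a `Thread`, so a slow or failing HTTP call can never delay or crash the construction of the error object. A minimal sketch of the same fire-and-forget shape; the reporting function here just simulates the request, and `build_error` is a hypothetical name rather than the library's API:

```python
import time
from threading import Thread

def report_error(code, message, method):
    try:
        time.sleep(0.1)                      # pretend this is the urlopen() call
        print('reported', code, message, method)
    except Exception:
        pass                                 # reporting must never crash the caller

def build_error(code, message, method=None):
    if method is not None:
        # Fire and forget: the caller does not wait for the report to finish
        Thread(target=report_error, args=(code, message, method)).start()
    return RuntimeError('{}: {}'.format(code, message))

err = build_error(400, 'PEER_ID_INVALID', method=0xDEADBEEF)
print(type(err).__name__)
```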
@@ -18,6 +18,15 @@ class BotMethodInvalidError(BadRequestError):
    )


class CdnMethodInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(
            self,
            'This method cannot be invoked on a CDN server. Refer to '
            'https://core.telegram.org/cdn#schema for available methods.'
        )


class ChannelInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(

@@ -134,6 +143,16 @@ class InputMethodInvalidError(BadRequestError):
    )


class InputRequestTooLongError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(
            self,
            'The input request was too long. This may be a bug in the library '
            'as it can occur when serializing more bytes than it should (like'
            'appending the vector constructor code at the end of a message).'
        )


class LastNameInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(

@@ -142,6 +161,24 @@ class LastNameInvalidError(BadRequestError):
    )


class LimitInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(
            self,
            'An invalid limit was provided. See '
            'https://core.telegram.org/api/files#downloading-files'
        )


class LocationInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(
            self,
            'The location given for a file was invalid. See '
            'https://core.telegram.org/api/files#downloading-files'
        )


class Md5ChecksumInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(

@@ -191,6 +228,16 @@ class MsgWaitFailedError(BadRequestError):
    )


class OffsetInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(
            self,
            'The given offset was invalid, it must be divisible by 1KB. '
            'See https://core.telegram.org/api/files#downloading-files'
        )


class PasswordHashInvalidError(BadRequestError):
    def __init__(self, **kwargs):
        super(Exception, self).__init__(

@@ -350,6 +397,7 @@ class UserIdInvalidError(BadRequestError):
rpc_errors_400_all = {
    'API_ID_INVALID': ApiIdInvalidError,
    'BOT_METHOD_INVALID': BotMethodInvalidError,
    'CDN_METHOD_INVALID': CdnMethodInvalidError,
    'CHANNEL_INVALID': ChannelInvalidError,
    'CHAT_ADMIN_REQUIRED': ChatAdminRequiredError,
    'CHAT_ID_INVALID': ChatIdInvalidError,

@@ -362,13 +410,17 @@ rpc_errors_400_all = {
    'FILE_PART_INVALID': FilePartInvalidError,
    'FIRSTNAME_INVALID': FirstNameInvalidError,
    'INPUT_METHOD_INVALID': InputMethodInvalidError,
    'INPUT_REQUEST_TOO_LONG': InputRequestTooLongError,
    'LASTNAME_INVALID': LastNameInvalidError,
    'LIMIT_INVALID': LimitInvalidError,
    'LOCATION_INVALID': LocationInvalidError,
    'MD5_CHECKSUM_INVALID': Md5ChecksumInvalidError,
    'MESSAGE_EMPTY': MessageEmptyError,
    'MESSAGE_ID_INVALID': MessageIdInvalidError,
    'MESSAGE_TOO_LONG': MessageTooLongError,
    'MESSAGE_NOT_MODIFIED': MessageNotModifiedError,
    'MSG_WAIT_FAILED': MsgWaitFailedError,
    'OFFSET_INVALID': OffsetInvalidError,
    'PASSWORD_HASH_INVALID': PasswordHashInvalidError,
    'PEER_ID_INVALID': PeerIdInvalidError,
    'PHONE_CODE_EMPTY': PhoneCodeEmptyError,
@@ -1,9 +1,7 @@
"""
Several extensions Python is missing, such as a proper class to handle a TCP
communication with support for cancelling the operation, and an utility class
to work with arbitrary binary data in a more comfortable way (writing ints,
strings, bytes, etc.)
to read arbitrary binary data in a more comfortable way, with int/strings/etc.
"""
from .binary_writer import BinaryWriter
from .binary_reader import BinaryReader
from .tcp_client import TcpClient
@ -1,152 +0,0 @@
|
|||
from io import BufferedWriter, BytesIO, DEFAULT_BUFFER_SIZE
|
||||
from struct import pack
|
||||
|
||||
|
||||
class BinaryWriter:
|
||||
"""
|
||||
Small utility class to write binary data.
|
||||
Also creates a "Memory Stream" if necessary
|
||||
"""
|
||||
|
||||
def __init__(self, stream=None, known_length=None):
|
||||
if not stream:
|
||||
stream = BytesIO()
|
||||
|
||||
if known_length is None:
|
||||
# On some systems, DEFAULT_BUFFER_SIZE defaults to 8192
|
||||
# That's over 16 times as big as necessary for most messages
|
||||
known_length = max(DEFAULT_BUFFER_SIZE, 1024)
|
||||
|
||||
self.writer = BufferedWriter(stream, buffer_size=known_length)
|
||||
self.written_count = 0
|
||||
|
||||
# region Writing
|
||||
|
||||
# "All numbers are written as little endian."
|
||||
# https://core.telegram.org/mtproto
|
||||
def write_byte(self, value):
|
||||
"""Writes a single byte value"""
|
||||
self.writer.write(pack('B', value))
|
||||
self.written_count += 1
|
||||
|
||||
def write_int(self, value, signed=True):
|
||||
"""Writes an integer value (4 bytes), optionally signed"""
|
||||
self.writer.write(
|
||||
int.to_bytes(
|
||||
value, length=4, byteorder='little', signed=signed))
|
||||
self.written_count += 4
|
||||
|
||||
def write_long(self, value, signed=True):
|
||||
"""Writes a long integer value (8 bytes), optionally signed"""
|
||||
self.writer.write(
|
||||
int.to_bytes(
|
||||
value, length=8, byteorder='little', signed=signed))
|
||||
self.written_count += 8
|
||||
|
||||
def write_float(self, value):
|
||||
"""Writes a floating point value (4 bytes)"""
|
||||
self.writer.write(pack('<f', value))
|
||||
self.written_count += 4
|
||||
|
||||
def write_double(self, value):
|
||||
"""Writes a floating point value (8 bytes)"""
|
||||
self.writer.write(pack('<d', value))
|
||||
self.written_count += 8
|
||||
|
||||
def write_large_int(self, value, bits, signed=True):
|
||||
"""Writes a n-bits long integer value"""
|
||||
self.writer.write(
|
||||
int.to_bytes(
|
||||
value, length=bits // 8, byteorder='little', signed=signed))
|
||||
self.written_count += bits // 8
|
||||
|
||||
def write(self, data):
|
||||
"""Writes the given bytes array"""
|
||||
self.writer.write(data)
|
||||
self.written_count += len(data)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Telegram custom writing
|
||||
|
||||
def tgwrite_bytes(self, data):
|
||||
"""Write bytes by using Telegram guidelines"""
|
||||
if len(data) < 254:
|
||||
padding = (len(data) + 1) % 4
|
||||
if padding != 0:
|
||||
padding = 4 - padding
|
||||
|
||||
self.write(bytes([len(data)]))
|
||||
self.write(data)
|
||||
|
||||
else:
|
||||
padding = len(data) % 4
|
||||
if padding != 0:
|
||||
padding = 4 - padding
|
||||
|
||||
self.write(bytes([254]))
|
||||
self.write(bytes([len(data) % 256]))
|
||||
self.write(bytes([(len(data) >> 8) % 256]))
|
||||
self.write(bytes([(len(data) >> 16) % 256]))
|
||||
self.write(data)
|
||||
|
||||
self.write(bytes(padding))
|
||||
|
||||
def tgwrite_string(self, string):
|
||||
"""Write a string by using Telegram guidelines"""
|
||||
self.tgwrite_bytes(string.encode('utf-8'))
|
||||
|
||||
def tgwrite_bool(self, boolean):
|
||||
"""Write a boolean value by using Telegram guidelines"""
|
||||
# boolTrue boolFalse
|
||||
self.write_int(0x997275b5 if boolean else 0xbc799737, signed=False)
|
||||
|
||||
def tgwrite_date(self, datetime):
|
||||
"""Converts a Python datetime object into Unix time
|
||||
(used by Telegram) and writes it
|
||||
"""
|
||||
value = 0 if datetime is None else int(datetime.timestamp())
|
||||
self.write_int(value)
|
||||
|
||||
def tgwrite_object(self, tlobject):
|
||||
"""Writes a Telegram object"""
|
||||
tlobject.on_send(self)
|
||||
|
||||
def tgwrite_vector(self, vector):
|
||||
"""Writes a vector of Telegram objects"""
|
||||
self.write_int(0x1cb5c415, signed=False) # Vector's constructor ID
|
||||
self.write_int(len(vector))
|
||||
for item in vector:
|
||||
self.tgwrite_object(item)
|
||||
|
||||
# endregion
|
||||
|
||||
def flush(self):
|
||||
"""Flush the current stream to "update" changes"""
|
||||
self.writer.flush()
|
||||
|
||||
def close(self):
|
||||
"""Close the current stream"""
|
||||
self.writer.close()
|
||||
|
||||
def get_bytes(self, flush=True):
|
||||
"""Get the current bytes array content from the buffer,
|
||||
optionally flushing first
|
||||
"""
|
||||
if flush:
|
||||
self.writer.flush()
|
||||
return self.writer.raw.getvalue()
|
||||
|
||||
def get_written_bytes_count(self):
|
||||
"""Gets the count of bytes written in the buffer.
|
||||
This may NOT be equal to the stream length if one
|
||||
was provided when initializing the writer
|
||||
"""
|
||||
return self.written_count
|
||||
|
||||
# with block
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
|
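The deleted `tgwrite_bytes` above implemented Telegram's TL `bytes` encoding: payloads shorter than 254 bytes get a one-byte length prefix, longer ones get a 0xFE marker plus a 3-byte little-endian length, and both forms are zero-padded to a multiple of four. A standalone sketch of that rule; the helper name is my own, not the library's replacement API:

```python
def tl_serialize_bytes(data: bytes) -> bytes:
    """Encode 'data' following Telegram's TL rules for the 'bytes' type."""
    if len(data) < 254:
        header = bytes([len(data)])                               # short form: 1-byte length
    else:
        header = bytes([254]) + len(data).to_bytes(3, 'little')   # long form: 0xFE + 3-byte length
    body = header + data
    return body + bytes(-len(body) % 4)                           # pad with zeros to 4-byte boundary

print(tl_serialize_bytes(b'abc').hex())        # 03616263 (already a multiple of 4)
print(len(tl_serialize_bytes(b'x' * 300)))     # 304: 4-byte header + 300 bytes, already aligned
```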
@ -8,41 +8,55 @@ from threading import Lock
|
|||
|
||||
class TcpClient:
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
self._proxy = proxy
|
||||
self.proxy = proxy
|
||||
self._socket = None
|
||||
self._closing_lock = Lock()
|
||||
|
||||
if isinstance(timeout, timedelta):
|
||||
self._timeout = timeout.seconds
|
||||
self.timeout = timeout.seconds
|
||||
elif isinstance(timeout, int) or isinstance(timeout, float):
|
||||
self._timeout = float(timeout)
|
||||
self.timeout = float(timeout)
|
||||
else:
|
||||
raise ValueError('Invalid timeout type', type(timeout))
|
||||
|
||||
def _recreate_socket(self, mode):
|
||||
if self._proxy is None:
|
||||
if self.proxy is None:
|
||||
self._socket = socket.socket(mode, socket.SOCK_STREAM)
|
||||
else:
|
||||
import socks
|
||||
self._socket = socks.socksocket(mode, socket.SOCK_STREAM)
|
||||
if type(self._proxy) is dict:
|
||||
self._socket.set_proxy(**self._proxy)
|
||||
if type(self.proxy) is dict:
|
||||
self._socket.set_proxy(**self.proxy)
|
||||
else: # tuple, list, etc.
|
||||
self._socket.set_proxy(*self._proxy)
|
||||
self._socket.set_proxy(*self.proxy)
|
||||
|
||||
self._socket.settimeout(self.timeout)
|
||||
|
||||
def connect(self, ip, port):
|
||||
"""Connects to the specified IP and port number.
|
||||
'timeout' must be given in seconds
|
||||
"""
|
||||
if not self.connected:
|
||||
if ':' in ip: # IPv6
|
||||
mode, address = socket.AF_INET6, (ip, port, 0, 0)
|
||||
else:
|
||||
mode, address = socket.AF_INET, (ip, port)
|
||||
if ':' in ip: # IPv6
|
||||
mode, address = socket.AF_INET6, (ip, port, 0, 0)
|
||||
else:
|
||||
mode, address = socket.AF_INET, (ip, port)
|
||||
|
||||
self._recreate_socket(mode)
|
||||
self._socket.settimeout(self._timeout)
|
||||
self._socket.connect(address)
|
||||
while True:
|
||||
try:
|
||||
while not self._socket:
|
||||
self._recreate_socket(mode)
|
||||
|
||||
self._socket.connect(address)
|
||||
break # Successful connection, stop retrying to connect
|
||||
except OSError as e:
|
||||
# There are some errors that we know how to handle, and
|
||||
# the loop will allow us to retry
|
||||
if e.errno == errno.EBADF:
|
||||
# Bad file descriptor, i.e. socket was closed, set it
|
||||
# to none to recreate it on the next iteration
|
||||
self._socket = None
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_connected(self):
|
||||
return self._socket is not None
|
||||
|
@ -67,6 +81,8 @@ class TcpClient:
|
|||
|
||||
def write(self, data):
|
||||
"""Writes (sends) the specified bytes to the connected peer"""
|
||||
if self._socket is None:
|
||||
raise ConnectionResetError()
|
||||
|
||||
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
|
||||
# The socket timeout is now the maximum total duration to send all data.
|
||||
|
@ -74,13 +90,13 @@ class TcpClient:
|
|||
self._socket.sendall(data)
|
||||
except socket.timeout as e:
|
||||
raise TimeoutError() from e
|
||||
except BrokenPipeError:
|
||||
self._raise_connection_reset()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EBADF:
|
||||
self._raise_connection_reset()
|
||||
else:
|
||||
raise
|
||||
except BrokenPipeError:
|
||||
self._raise_connection_reset()
|
||||
|
||||
def read(self, size):
|
||||
"""Reads (receives) a whole block of 'size bytes
|
||||
|
@ -91,6 +107,9 @@ class TcpClient:
|
|||
and it's waiting for more, the timeout will NOT cancel the
|
||||
operation. Set to None for no timeout
|
||||
"""
|
||||
if self._socket is None:
|
||||
raise ConnectionResetError()
|
||||
|
||||
# TODO Remove the timeout from this method, always use previous one
|
||||
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
||||
bytes_left = size
|
||||
|
@ -100,7 +119,7 @@ class TcpClient:
|
|||
except socket.timeout as e:
|
||||
raise TimeoutError() from e
|
||||
except OSError as e:
|
||||
if e.errno == errno.EBADF:
|
||||
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK:
|
||||
self._raise_connection_reset()
|
||||
else:
|
||||
raise
|
||||
|
|
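The new `connect` above loops so that a socket whose descriptor has gone bad (`EBADF`) is dropped and recreated instead of aborting the connection attempt. A stripped-down sketch of that retry shape, using a plain factory function to build the socket; the function name and placeholder address are assumptions for illustration:

```python
import errno
import socket

def connect_with_retry(address,
                       make_socket=lambda: socket.socket(socket.AF_INET, socket.SOCK_STREAM)):
    sock = None
    while True:
        try:
            if sock is None:
                sock = make_socket()        # (re)create the socket when needed
            sock.connect(address)
            return sock                     # successful connection, stop retrying
        except OSError as e:
            if e.errno == errno.EBADF:
                sock = None                 # bad file descriptor: rebuild on the next iteration
            else:
                raise                       # anything else is a real error

# Example (placeholder address; expects something listening on localhost:8080)
# sock = connect_with_retry(('127.0.0.1', 8080))
```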
@@ -47,9 +47,11 @@ def calc_msg_key(data):

def generate_key_data_from_nonce(server_nonce, new_nonce):
    """Generates the key data corresponding to the given nonce"""
    hash1 = sha1(bytes(new_nonce + server_nonce)).digest()
    hash2 = sha1(bytes(server_nonce + new_nonce)).digest()
    hash3 = sha1(bytes(new_nonce + new_nonce)).digest()
    server_nonce = server_nonce.to_bytes(16, 'little', signed=True)
    new_nonce = new_nonce.to_bytes(32, 'little', signed=True)
    hash1 = sha1(new_nonce + server_nonce).digest()
    hash2 = sha1(server_nonce + new_nonce).digest()
    hash3 = sha1(new_nonce + new_nonce).digest()

    key = hash1 + hash2[:12]
    iv = hash2[12:20] + hash3 + new_nonce[:4]
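The updated `generate_key_data_from_nonce` now serializes both nonces itself (16-byte and 32-byte little-endian) before hashing, then derives a 32-byte AES key and IV from three SHA1 digests. A self-contained sketch mirroring that derivation with random nonces:

```python
import os
from hashlib import sha1

def generate_key_data_from_nonce(server_nonce: int, new_nonce: int):
    server_nonce = server_nonce.to_bytes(16, 'little', signed=True)
    new_nonce = new_nonce.to_bytes(32, 'little', signed=True)
    hash1 = sha1(new_nonce + server_nonce).digest()
    hash2 = sha1(server_nonce + new_nonce).digest()
    hash3 = sha1(new_nonce + new_nonce).digest()
    key = hash1 + hash2[:12]                     # 32-byte AES key
    iv = hash2[12:20] + hash3 + new_nonce[:4]    # 32-byte AES IGE IV
    return key, iv

server_nonce = int.from_bytes(os.urandom(16), 'big', signed=True)
new_nonce = int.from_bytes(os.urandom(32), 'big', signed=True)
key, iv = generate_key_data_from_nonce(server_nonce, new_nonce)
print(len(key), len(iv))   # 32 32
```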
@ -2,12 +2,19 @@ import os
|
|||
import time
|
||||
from hashlib import sha1
|
||||
|
||||
from ..tl.types import (
|
||||
ResPQ, PQInnerData, ServerDHParamsFail, ServerDHParamsOk,
|
||||
ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail
|
||||
)
|
||||
from .. import helpers as utils
|
||||
from ..crypto import AES, AuthKey, Factorization
|
||||
from ..crypto import rsa
|
||||
from ..errors import SecurityError, TypeNotFoundError
|
||||
from ..extensions import BinaryReader, BinaryWriter
|
||||
from ..errors import SecurityError
|
||||
from ..extensions import BinaryReader
|
||||
from ..network import MtProtoPlainSender
|
||||
from ..tl.functions import (
|
||||
ReqPqRequest, ReqDHParamsRequest, SetClientDHParamsRequest
|
||||
)
|
||||
|
||||
|
||||
def do_authentication(connection, retries=5):
|
||||
|
@ -18,7 +25,7 @@ def do_authentication(connection, retries=5):
|
|||
while retries:
|
||||
try:
|
||||
return _do_authentication(connection)
|
||||
except (SecurityError, TypeNotFoundError, NotImplementedError) as e:
|
||||
except (SecurityError, AssertionError, NotImplementedError) as e:
|
||||
last_error = e
|
||||
retries -= 1
|
||||
raise last_error
|
||||
|
@ -30,202 +37,158 @@ def _do_authentication(connection):
|
|||
time offset.
|
||||
"""
|
||||
sender = MtProtoPlainSender(connection)
|
||||
sender.connect()
|
||||
|
||||
# Step 1 sending: PQ Request
|
||||
nonce = os.urandom(16)
|
||||
with BinaryWriter(known_length=20) as writer:
|
||||
writer.write_int(0x60469778, signed=False) # Constructor number
|
||||
writer.write(nonce)
|
||||
sender.send(writer.get_bytes())
|
||||
|
||||
# Step 1 response: PQ Request
|
||||
pq, pq_bytes, server_nonce, fingerprints = None, None, None, []
|
||||
# Step 1 sending: PQ Request, endianness doesn't matter since it's random
|
||||
req_pq_request = ReqPqRequest(
|
||||
nonce=int.from_bytes(os.urandom(16), 'big', signed=True)
|
||||
)
|
||||
sender.send(req_pq_request.to_bytes())
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
response_code = reader.read_int(signed=False)
|
||||
if response_code != 0x05162463:
|
||||
raise TypeNotFoundError(response_code)
|
||||
req_pq_request.on_response(reader)
|
||||
|
||||
nonce_from_server = reader.read(16)
|
||||
if nonce_from_server != nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
res_pq = req_pq_request.result
|
||||
if not isinstance(res_pq, ResPQ):
|
||||
raise AssertionError(res_pq)
|
||||
|
||||
server_nonce = reader.read(16)
|
||||
if res_pq.nonce != req_pq_request.nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
|
||||
pq_bytes = reader.tgread_bytes()
|
||||
pq = get_int(pq_bytes)
|
||||
|
||||
vector_id = reader.read_int()
|
||||
if vector_id != 0x1cb5c415:
|
||||
raise TypeNotFoundError(response_code)
|
||||
|
||||
fingerprints = []
|
||||
fingerprint_count = reader.read_int()
|
||||
for _ in range(fingerprint_count):
|
||||
fingerprints.append(reader.read(8))
|
||||
pq = get_int(res_pq.pq)
|
||||
|
||||
# Step 2 sending: DH Exchange
|
||||
new_nonce = os.urandom(32)
|
||||
p, q = Factorization.factorize(pq)
|
||||
with BinaryWriter() as pq_inner_data_writer:
|
||||
pq_inner_data_writer.write_int(
|
||||
0x83c95aec, signed=False) # PQ Inner Data
|
||||
pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(pq))
|
||||
pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(min(p, q)))
|
||||
pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(max(p, q)))
|
||||
pq_inner_data_writer.write(nonce)
|
||||
pq_inner_data_writer.write(server_nonce)
|
||||
pq_inner_data_writer.write(new_nonce)
|
||||
p, q = rsa.get_byte_array(min(p, q)), rsa.get_byte_array(max(p, q))
|
||||
new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True)
|
||||
|
||||
# sha_digest + data + random_bytes
|
||||
cipher_text, target_fingerprint = None, None
|
||||
for fingerprint in fingerprints:
|
||||
cipher_text = rsa.encrypt(
|
||||
fingerprint,
|
||||
pq_inner_data_writer.get_bytes()
|
||||
pq_inner_data = PQInnerData(
|
||||
pq=rsa.get_byte_array(pq), p=p, q=q,
|
||||
nonce=res_pq.nonce,
|
||||
server_nonce=res_pq.server_nonce,
|
||||
new_nonce=new_nonce
|
||||
).to_bytes()
|
||||
|
||||
# sha_digest + data + random_bytes
|
||||
cipher_text, target_fingerprint = None, None
|
||||
for fingerprint in res_pq.server_public_key_fingerprints:
|
||||
cipher_text = rsa.encrypt(fingerprint, pq_inner_data)
|
||||
if cipher_text is not None:
|
||||
target_fingerprint = fingerprint
|
||||
break
|
||||
|
||||
if cipher_text is None:
|
||||
raise SecurityError(
|
||||
'Could not find a valid key for fingerprints: {}'
|
||||
.format(', '.join(
|
||||
[str(f) for f in res_pq.server_public_key_fingerprints])
|
||||
)
|
||||
)
|
||||
|
||||
if cipher_text is not None:
|
||||
target_fingerprint = fingerprint
|
||||
break
|
||||
|
||||
if cipher_text is None:
|
||||
raise SecurityError(
|
||||
'Could not find a valid key for fingerprints: {}'
|
||||
.format(', '.join([repr(f) for f in fingerprints]))
|
||||
)
|
||||
|
||||
with BinaryWriter() as req_dh_params_writer:
|
||||
req_dh_params_writer.write_int(
|
||||
0xd712e4be, signed=False) # Req DH Params
|
||||
req_dh_params_writer.write(nonce)
|
||||
req_dh_params_writer.write(server_nonce)
|
||||
req_dh_params_writer.tgwrite_bytes(rsa.get_byte_array(min(p, q)))
|
||||
req_dh_params_writer.tgwrite_bytes(rsa.get_byte_array(max(p, q)))
|
||||
req_dh_params_writer.write(target_fingerprint)
|
||||
req_dh_params_writer.tgwrite_bytes(cipher_text)
|
||||
|
||||
req_dh_params_bytes = req_dh_params_writer.get_bytes()
|
||||
sender.send(req_dh_params_bytes)
|
||||
req_dh_params = ReqDHParamsRequest(
|
||||
nonce=res_pq.nonce,
|
||||
server_nonce=res_pq.server_nonce,
|
||||
p=p, q=q,
|
||||
public_key_fingerprint=target_fingerprint,
|
||||
encrypted_data=cipher_text
|
||||
)
|
||||
sender.send(req_dh_params.to_bytes())
|
||||
|
||||
# Step 2 response: DH Exchange
|
||||
encrypted_answer = None
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
response_code = reader.read_int(signed=False)
|
||||
req_dh_params.on_response(reader)
|
||||
|
||||
if response_code == 0x79cb045d:
|
||||
raise SecurityError('Server DH params fail: TODO')
|
||||
server_dh_params = req_dh_params.result
|
||||
if isinstance(server_dh_params, ServerDHParamsFail):
|
||||
raise SecurityError('Server DH params fail: TODO')
|
||||
|
||||
if response_code != 0xd0e8075c:
|
||||
raise TypeNotFoundError(response_code)
|
||||
if not isinstance(server_dh_params, ServerDHParamsOk):
|
||||
raise AssertionError(server_dh_params)
|
||||
|
||||
nonce_from_server = reader.read(16)
|
||||
if nonce_from_server != nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
if server_dh_params.nonce != res_pq.nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
|
||||
server_nonce_from_server = reader.read(16)
|
||||
if server_nonce_from_server != server_nonce:
|
||||
raise SecurityError('Invalid server nonce from server')
|
||||
|
||||
encrypted_answer = reader.tgread_bytes()
|
||||
if server_dh_params.server_nonce != res_pq.server_nonce:
|
||||
raise SecurityError('Invalid server nonce from server')
|
||||
|
||||
# Step 3 sending: Complete DH Exchange
|
||||
key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce)
|
||||
plain_text_answer = AES.decrypt_ige(encrypted_answer, key, iv)
|
||||
key, iv = utils.generate_key_data_from_nonce(
|
||||
res_pq.server_nonce, new_nonce
|
||||
)
|
||||
plain_text_answer = AES.decrypt_ige(
|
||||
server_dh_params.encrypted_answer, key, iv
|
||||
)
|
||||
|
||||
g, dh_prime, ga, time_offset = None, None, None, None
|
||||
with BinaryReader(plain_text_answer) as dh_inner_data_reader:
|
||||
dh_inner_data_reader.read(20) # hash sum
|
||||
code = dh_inner_data_reader.read_int(signed=False)
|
||||
if code != 0xb5890dba:
|
||||
raise TypeNotFoundError(code)
|
||||
with BinaryReader(plain_text_answer) as reader:
|
||||
reader.read(20) # hash sum
|
||||
server_dh_inner = reader.tgread_object()
|
||||
if not isinstance(server_dh_inner, ServerDHInnerData):
|
||||
raise AssertionError(server_dh_inner)
|
||||
|
||||
nonce_from_server1 = dh_inner_data_reader.read(16)
|
||||
if nonce_from_server1 != nonce:
|
||||
raise SecurityError('Invalid nonce in encrypted answer')
|
||||
if server_dh_inner.nonce != res_pq.nonce:
|
||||
print(server_dh_inner.nonce, res_pq.nonce)
|
||||
raise SecurityError('Invalid nonce in encrypted answer')
|
||||
|
||||
server_nonce_from_server1 = dh_inner_data_reader.read(16)
|
||||
if server_nonce_from_server1 != server_nonce:
|
||||
raise SecurityError('Invalid server nonce in encrypted answer')
|
||||
if server_dh_inner.server_nonce != res_pq.server_nonce:
|
||||
raise SecurityError('Invalid server nonce in encrypted answer')
|
||||
|
||||
g = dh_inner_data_reader.read_int()
|
||||
dh_prime = get_int(dh_inner_data_reader.tgread_bytes(), signed=False)
|
||||
ga = get_int(dh_inner_data_reader.tgread_bytes(), signed=False)
|
||||
|
||||
server_time = dh_inner_data_reader.read_int()
|
||||
time_offset = server_time - int(time.time())
|
||||
dh_prime = get_int(server_dh_inner.dh_prime, signed=False)
|
||||
g_a = get_int(server_dh_inner.g_a, signed=False)
|
||||
time_offset = server_dh_inner.server_time - int(time.time())
|
||||
|
||||
b = get_int(os.urandom(256), signed=False)
|
||||
gb = pow(g, b, dh_prime)
|
||||
gab = pow(ga, b, dh_prime)
|
||||
gb = pow(server_dh_inner.g, b, dh_prime)
|
||||
gab = pow(g_a, b, dh_prime)
|
||||
|
||||
# Prepare client DH Inner Data
|
||||
with BinaryWriter() as client_dh_inner_data_writer:
|
||||
client_dh_inner_data_writer.write_int(
|
||||
0x6643b654, signed=False) # Client DH Inner Data
|
||||
client_dh_inner_data_writer.write(nonce)
|
||||
client_dh_inner_data_writer.write(server_nonce)
|
||||
client_dh_inner_data_writer.write_long(0) # TODO retry_id
|
||||
client_dh_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(gb))
|
||||
client_dh_inner = ClientDHInnerData(
|
||||
nonce=res_pq.nonce,
|
||||
server_nonce=res_pq.server_nonce,
|
||||
retry_id=0, # TODO Actual retry ID
|
||||
g_b=rsa.get_byte_array(gb)
|
||||
).to_bytes()
|
||||
|
||||
with BinaryWriter() as client_dh_inner_data_with_hash_writer:
|
||||
client_dh_inner_data_with_hash_writer.write(
|
||||
sha1(client_dh_inner_data_writer.get_bytes()).digest())
|
||||
|
||||
client_dh_inner_data_with_hash_writer.write(
|
||||
client_dh_inner_data_writer.get_bytes())
|
||||
|
||||
client_dh_inner_data_bytes = \
|
||||
client_dh_inner_data_with_hash_writer.get_bytes()
|
||||
client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner
|
||||
|
||||
# Encryption
|
||||
client_dh_inner_data_encrypted_bytes = AES.encrypt_ige(
|
||||
client_dh_inner_data_bytes, key, iv)
|
||||
client_dh_encrypted = AES.encrypt_ige(client_dh_inner_hashed, key, iv)
|
||||
|
||||
# Prepare Set client DH params
|
||||
with BinaryWriter() as set_client_dh_params_writer:
|
||||
set_client_dh_params_writer.write_int(0xf5045f1f, signed=False)
|
||||
set_client_dh_params_writer.write(nonce)
|
||||
set_client_dh_params_writer.write(server_nonce)
|
||||
set_client_dh_params_writer.tgwrite_bytes(
|
||||
client_dh_inner_data_encrypted_bytes)
|
||||
|
||||
set_client_dh_params_bytes = set_client_dh_params_writer.get_bytes()
|
||||
sender.send(set_client_dh_params_bytes)
|
||||
set_client_dh = SetClientDHParamsRequest(
|
||||
nonce=res_pq.nonce,
|
||||
server_nonce=res_pq.server_nonce,
|
||||
encrypted_data=client_dh_encrypted,
|
||||
)
|
||||
sender.send(set_client_dh.to_bytes())
|
||||
|
||||
# Step 3 response: Complete DH Exchange
|
||||
with BinaryReader(sender.receive()) as reader:
|
||||
# Everything read from the server, disconnect now
|
||||
sender.disconnect()
|
||||
set_client_dh.on_response(reader)
|
||||
|
||||
code = reader.read_int(signed=False)
|
||||
if code == 0x3bcbf734: # DH Gen OK
|
||||
nonce_from_server = reader.read(16)
|
||||
if nonce_from_server != nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
dh_gen = set_client_dh.result
|
||||
if isinstance(dh_gen, DhGenOk):
|
||||
if dh_gen.nonce != res_pq.nonce:
|
||||
raise SecurityError('Invalid nonce from server')
|
||||
|
||||
server_nonce_from_server = reader.read(16)
|
||||
if server_nonce_from_server != server_nonce:
|
||||
raise SecurityError('Invalid server nonce from server')
|
||||
if dh_gen.server_nonce != res_pq.server_nonce:
|
||||
raise SecurityError('Invalid server nonce from server')
|
||||
|
||||
new_nonce_hash1 = reader.read(16)
|
||||
auth_key = AuthKey(rsa.get_byte_array(gab))
|
||||
auth_key = AuthKey(rsa.get_byte_array(gab))
|
||||
new_nonce_hash = int.from_bytes(
|
||||
auth_key.calc_new_nonce_hash(new_nonce, 1), 'little', signed=True
|
||||
)
|
||||
|
||||
new_nonce_hash_calculated = auth_key.calc_new_nonce_hash(new_nonce,
|
||||
1)
|
||||
if new_nonce_hash1 != new_nonce_hash_calculated:
|
||||
raise SecurityError('Invalid new nonce hash')
|
||||
if dh_gen.new_nonce_hash1 != new_nonce_hash:
|
||||
raise SecurityError('Invalid new nonce hash')
|
||||
|
||||
return auth_key, time_offset
|
||||
return auth_key, time_offset
|
||||
|
||||
elif code == 0x46dc1fb9: # DH Gen Retry
|
||||
raise NotImplementedError('dh_gen_retry')
|
||||
elif isinstance(dh_gen, DhGenRetry):
|
||||
raise NotImplementedError('DhGenRetry')
|
||||
|
||||
elif code == 0xa69dae02: # DH Gen Fail
|
||||
raise NotImplementedError('dh_gen_fail')
|
||||
elif isinstance(dh_gen, DhGenFail):
|
||||
raise NotImplementedError('DhGenFail')
|
||||
|
||||
else:
|
||||
raise NotImplementedError('DH Gen unknown: {}'.format(hex(code)))
|
||||
else:
|
||||
raise NotImplementedError('DH Gen unknown: {}'.format(dh_gen))
|
||||
|
||||
|
||||
def get_int(byte_array, signed=True):
|
||||
|
|
|
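Underneath the TL-object refactor, the authenticator above still performs the same Diffie-Hellman step: pick a random client secret `b`, send g_b = g^b mod dh_prime inside client_DH_inner_data, and combine the server's g_a with `b` to obtain the shared g_ab that becomes the auth key material. A toy sketch of just that modular arithmetic, using a small prime and short secrets instead of Telegram's real 2048-bit parameters:

```python
import os

def dh_client_step(g: int, g_a: int, dh_prime: int):
    b = int.from_bytes(os.urandom(32), 'big')   # toy 256-bit secret (the real one is 2048-bit)
    g_b = pow(g, b, dh_prime)                   # sent to the server
    g_ab = pow(g_a, b, dh_prime)                # shared secret
    return g_b, g_ab

# Toy parameters (NOT Telegram's real dh_prime); the server supplies g, g_a and dh_prime
dh_prime = 2 ** 127 - 1                         # a Mersenne prime, just for the demo
g = 3
g_a = pow(g, 123456789, dh_prime)               # pretend the server's secret a is 123456789
g_b, g_ab = dh_client_step(g, g_a, dh_prime)
print(g_ab == pow(g_b, 123456789, dh_prime))    # True: both sides derive the same secret
```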
@ -1,10 +1,13 @@
|
|||
import os
|
||||
import struct
|
||||
from datetime import timedelta
|
||||
from zlib import crc32
|
||||
from enum import Enum
|
||||
|
||||
import errno
|
||||
|
||||
from ..crypto import AESModeCTR
|
||||
from ..extensions import BinaryWriter, TcpClient
|
||||
from ..extensions import TcpClient
|
||||
from ..errors import InvalidChecksumError
|
||||
|
||||
|
||||
|
@ -75,9 +78,15 @@ class Connection:
|
|||
setattr(self, 'read', self._read_plain)
|
||||
|
||||
def connect(self):
|
||||
self._send_counter = 0
|
||||
self.conn.connect(self.ip, self.port)
|
||||
try:
|
||||
self.conn.connect(self.ip, self.port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EISCONN:
|
||||
return # Already connected, no need to re-set everything up
|
||||
else:
|
||||
raise
|
||||
|
||||
self._send_counter = 0
|
||||
if self._mode == ConnectionMode.TCP_ABRIDGED:
|
||||
self.conn.write(b'\xef')
|
||||
elif self._mode == ConnectionMode.TCP_INTERMEDIATE:
|
||||
|
@ -85,6 +94,9 @@ class Connection:
|
|||
elif self._mode == ConnectionMode.TCP_OBFUSCATED:
|
||||
self._setup_obfuscation()
|
||||
|
||||
def get_timeout(self):
|
||||
return self.conn.timeout
|
||||
|
||||
def _setup_obfuscation(self):
|
||||
# Obfuscated messages secrets cannot start with any of these
|
||||
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
|
||||
|
@ -118,6 +130,13 @@ class Connection:
|
|||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def clone(self):
|
||||
"""Creates a copy of this Connection"""
|
||||
return Connection(self.ip, self.port,
|
||||
mode=self._mode,
|
||||
proxy=self.conn.proxy,
|
||||
timeout=self.conn.timeout)
|
||||
|
||||
# region Receive message implementations
|
||||
|
||||
def recv(self):
|
||||
|
@ -164,30 +183,22 @@ class Connection:
|
|||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
length = len(message) + 12
|
||||
with BinaryWriter(known_length=length) as writer:
|
||||
writer.write_int(length)
|
||||
writer.write_int(self._send_counter)
|
||||
writer.write(message)
|
||||
writer.write_int(crc32(writer.get_bytes()), signed=False)
|
||||
self._send_counter += 1
|
||||
self.write(writer.get_bytes())
|
||||
data = struct.pack('<ii', length, self._send_counter) + message
|
||||
crc = struct.pack('<I', crc32(data))
|
||||
self._send_counter += 1
|
||||
self.write(data + crc)
|
||||
|
||||
def _send_intermediate(self, message):
|
||||
with BinaryWriter(known_length=len(message) + 4) as writer:
|
||||
writer.write_int(len(message))
|
||||
writer.write(message)
|
||||
self.write(writer.get_bytes())
|
||||
self.write(struct.pack('<i', len(message)) + message)
|
||||
|
||||
def _send_abridged(self, message):
|
||||
with BinaryWriter(known_length=len(message) + 4) as writer:
|
||||
length = len(message) >> 2
|
||||
if length < 127:
|
||||
writer.write_byte(length)
|
||||
else:
|
||||
writer.write_byte(127)
|
||||
writer.write(int.to_bytes(length, 3, 'little'))
|
||||
writer.write(message)
|
||||
self.write(writer.get_bytes())
|
||||
length = len(message) >> 2
|
||||
if length < 127:
|
||||
length = struct.pack('B', length)
|
||||
else:
|
||||
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
||||
|
||||
self.write(length + message)
|
||||
|
||||
# endregion
|
||||
|
||||
|
|
|
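The rewritten send helpers above replace `BinaryWriter` with direct `struct` packing for the three TCP transports: full (length + sequence + payload + CRC32), intermediate (length + payload) and abridged (length/4 as one byte, or 0x7f plus three length bytes). A sketch of the three framings as free functions; the function names are hypothetical and the send counter is passed in explicitly:

```python
import struct
from zlib import crc32

def frame_tcp_full(message: bytes, send_counter: int) -> bytes:
    length = len(message) + 12                              # total length, seq, payload, CRC32
    data = struct.pack('<ii', length, send_counter) + message
    return data + struct.pack('<I', crc32(data))

def frame_intermediate(message: bytes) -> bytes:
    return struct.pack('<i', len(message)) + message

def frame_abridged(message: bytes) -> bytes:
    length = len(message) >> 2                              # length in 4-byte words
    if length < 127:
        return struct.pack('B', length) + message
    return b'\x7f' + length.to_bytes(3, 'little') + message

msg = b'\x00' * 16                                          # dummy payload, already 4-byte aligned
print(len(frame_tcp_full(msg, 0)), len(frame_intermediate(msg)), len(frame_abridged(msg)))
# 28 20 17
```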
@@ -1,7 +1,8 @@
import struct
import time

from ..errors import BrokenAuthKeyError
from ..extensions import BinaryReader, BinaryWriter
from ..extensions import BinaryReader


class MtProtoPlainSender:

@@ -25,14 +26,9 @@ class MtProtoPlainSender:
        """Sends a plain packet (auth_key_id = 0) containing the
        given message body (data)
        """
        with BinaryWriter(known_length=len(data) + 20) as writer:
            writer.write_long(0)
            writer.write_long(self._get_new_msg_id())
            writer.write_int(len(data))
            writer.write(data)

            packet = writer.get_bytes()
            self._connection.send(packet)
        self._connection.send(
            struct.pack('<QQi', 0, self._get_new_msg_id(), len(data)) + data
        )

    def receive(self):
        """Receives a plain packet, returning the body of the response"""
@ -1,41 +1,45 @@
|
|||
import gzip
|
||||
import logging
|
||||
from threading import RLock
|
||||
import struct
|
||||
|
||||
from .. import helpers as utils
|
||||
from ..crypto import AES
|
||||
from ..errors import BadMessageError, InvalidChecksumError, rpc_message_to_error
|
||||
from ..extensions import BinaryReader, BinaryWriter
|
||||
from ..errors import (
|
||||
BadMessageError, InvalidChecksumError, BrokenAuthKeyError,
|
||||
rpc_message_to_error
|
||||
)
|
||||
from ..extensions import BinaryReader
|
||||
from ..tl import TLMessage, MessageContainer, GzipPacked
|
||||
from ..tl.all_tlobjects import tlobjects
|
||||
from ..tl.types import MsgsAck
|
||||
from ..tl.functions.auth import LogOutRequest
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
class MtProtoSender:
|
||||
"""MTProto Mobile Protocol sender
|
||||
(https://core.telegram.org/mtproto/description)
|
||||
(https://core.telegram.org/mtproto/description).
|
||||
|
||||
Note that this class is not thread-safe, and calling send/receive
|
||||
from two or more threads at the same time is undefined behaviour.
|
||||
Rationale: a new connection should be spawned to send/receive requests
|
||||
in parallel, so thread-safety (hence locking) isn't needed.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, session):
|
||||
def __init__(self, session, connection):
|
||||
"""Creates a new MtProtoSender configured to send messages through
|
||||
'connection' and using the parameters from 'session'.
|
||||
"""
|
||||
self.connection = connection
|
||||
self.session = session
|
||||
self.connection = connection
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
self._need_confirmation = [] # Message IDs that need confirmation
|
||||
self._pending_receive = [] # Requests sent waiting to be received
|
||||
# Message IDs that need confirmation
|
||||
self._need_confirmation = []
|
||||
|
||||
# Sending and receiving are independent, but two threads cannot
|
||||
# send or receive at the same time no matter what.
|
||||
self._send_lock = RLock()
|
||||
self._recv_lock = RLock()
|
||||
|
||||
# Used when logging out, the only request that seems to use 'ack'
|
||||
# TODO There might be a better way to handle msgs_ack requests
|
||||
self.logging_out = False
|
||||
# Requests (as msg_id: Message) sent waiting to be received
|
||||
self._pending_receive = {}
|
||||
|
||||
def connect(self):
|
||||
"""Connects to the server"""
|
||||
|
@ -47,33 +51,39 @@ class MtProtoSender:
|
|||
def disconnect(self):
|
||||
"""Disconnects from the server"""
|
||||
self.connection.close()
|
||||
self._need_confirmation.clear()
|
||||
self._clear_all_pending()
|
||||
|
||||
def clone(self):
|
||||
"""Creates a copy of this MtProtoSender as a new connection"""
|
||||
return MtProtoSender(self.session, self.connection.clone())
|
||||
|
||||
# region Send and receive
|
||||
|
||||
def send(self, request):
|
||||
def send(self, *requests):
|
||||
"""Sends the specified MTProtoRequest, previously sending any message
|
||||
which needed confirmation."""
|
||||
|
||||
# If any message needs confirmation send an AckRequest first
|
||||
self._send_acknowledges()
|
||||
|
||||
# Finally send our packed request
|
||||
with BinaryWriter() as writer:
|
||||
request.on_send(writer)
|
||||
self._send_packet(writer.get_bytes(), request)
|
||||
self._pending_receive.append(request)
|
||||
# Finally send our packed request(s)
|
||||
messages = [TLMessage(self.session, r) for r in requests]
|
||||
self._pending_receive.update({m.msg_id: m for m in messages})
|
||||
|
||||
# And update the saved session
|
||||
self.session.save()
|
||||
if len(messages) == 1:
|
||||
message = messages[0]
|
||||
else:
|
||||
message = TLMessage(self.session, MessageContainer(messages))
|
||||
|
||||
self._send_message(message)
|
||||
|
||||
def _send_acknowledges(self):
|
||||
"""Sends a messages acknowledge for all those who _need_confirmation"""
|
||||
if self._need_confirmation:
|
||||
msgs_ack = MsgsAck(self._need_confirmation)
|
||||
with BinaryWriter() as writer:
|
||||
msgs_ack.on_send(writer)
|
||||
self._send_packet(writer.get_bytes(), msgs_ack)
|
||||
|
||||
self._send_message(
|
||||
TLMessage(self.session, MsgsAck(self._need_confirmation))
|
||||
)
|
||||
del self._need_confirmation[:]
|
||||
|
||||
def receive(self, update_state):
|
||||
|
@ -86,21 +96,18 @@ class MtProtoSender:
|
|||
Any unhandled object (likely updates) will be passed to
|
||||
update_state.process(TLObject).
|
||||
"""
|
||||
with self._recv_lock:
|
||||
try:
|
||||
body = self.connection.recv()
|
||||
except (BufferError, InvalidChecksumError):
|
||||
# TODO BufferError, we should spot the cause...
|
||||
# "No more bytes left"; something wrong happened, clear
|
||||
# everything to be on the safe side, or:
|
||||
#
|
||||
# "This packet should be skipped"; since this may have
|
||||
# been a result for a request, invalidate every request
|
||||
# and just re-invoke them to avoid problems
|
||||
for r in self._pending_receive:
|
||||
r.confirm_received.set()
|
||||
self._pending_receive.clear()
|
||||
return
|
||||
try:
|
||||
body = self.connection.recv()
|
||||
except (BufferError, InvalidChecksumError):
|
||||
# TODO BufferError, we should spot the cause...
|
||||
# "No more bytes left"; something wrong happened, clear
|
||||
# everything to be on the safe side, or:
|
||||
#
|
||||
# "This packet should be skipped"; since this may have
|
||||
# been a result for a request, invalidate every request
|
||||
# and just re-invoke them to avoid problems
|
||||
self._clear_all_pending()
|
||||
return
|
||||
|
||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
with BinaryReader(message) as reader:
|
||||
|
@ -110,36 +117,20 @@ class MtProtoSender:
|
|||
|
||||
# region Low level processing

def _send_packet(self, packet, request):
"""Sends the given packet bytes with the additional
information of the original request.
"""
request.request_msg_id = self.session.get_new_msg_id()
def _send_message(self, message):
"""Sends the given Message(TLObject) encrypted through the network"""

# First calculate plain_text to encrypt it
with BinaryWriter() as plain_writer:
plain_writer.write_long(self.session.salt, signed=False)
plain_writer.write_long(self.session.id, signed=False)
plain_writer.write_long(request.request_msg_id)
plain_writer.write_int(
self.session.generate_sequence(request.content_related))
plain_text = \
struct.pack('<QQ', self.session.salt, self.session.id) \
+ message.to_bytes()

plain_writer.write_int(len(packet))
plain_writer.write(packet)
msg_key = utils.calc_msg_key(plain_text)
key_id = struct.pack('<Q', self.session.auth_key.key_id)
key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True)
cipher_text = AES.encrypt_ige(plain_text, key, iv)

msg_key = utils.calc_msg_key(plain_writer.get_bytes())

key, iv = utils.calc_key(self.session.auth_key.key, msg_key, True)
cipher_text = AES.encrypt_ige(plain_writer.get_bytes(), key, iv)

# And then finally send the encrypted packet
with BinaryWriter() as cipher_writer:
cipher_writer.write_long(
self.session.auth_key.key_id, signed=False)
cipher_writer.write(msg_key)
cipher_writer.write(cipher_text)
with self._send_lock:
self.connection.send(cipher_writer.get_bytes())
result = key_id + msg_key + cipher_text
self.connection.send(result)
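For reference, a minimal standalone sketch of the packet layout the new _send_message builds: salt + session id + message bytes in the clear, then auth_key_id + msg_key + ciphertext on the wire. The SHA1-middle-bytes message key and the injectable 'encrypt' callable are assumptions standing in for utils.calc_msg_key / AES.encrypt_ige, not the library's real crypto helpers.

import struct
from hashlib import sha1

def pack_encrypted(salt, session_id, message_bytes, auth_key_id, encrypt):
    # plain = salt + session_id + TLMessage bytes, mirroring _send_message above
    plain = struct.pack('<QQ', salt, session_id) + message_bytes
    # Middle 16 bytes of SHA1(plain) stand in for utils.calc_msg_key (assumption)
    msg_key = sha1(plain).digest()[4:20]
    # The real code derives (key, iv) from the auth key and runs AES-IGE here
    cipher_text = encrypt(plain, msg_key)
    return struct.pack('<Q', auth_key_id) + msg_key + cipher_text

# A dummy "encryption" callable keeps the sketch runnable without crypto deps
packet = pack_encrypted(1, 2, b'\x00' * 16, 3, lambda plain, key: plain)
assert packet[:8] == struct.pack('<Q', 3)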
|
||||
|
||||
def _decode_msg(self, body):
"""Decodes a received encrypted message body (bytes)"""

@@ -149,7 +140,10 @@ class MtProtoSender:
|
||||
with BinaryReader(body) as reader:
|
||||
if len(body) < 8:
|
||||
raise BufferError("Can't decode packet ({})".format(body))
|
||||
if body == b'l\xfe\xff\xff':
|
||||
raise BrokenAuthKeyError()
|
||||
else:
|
||||
raise BufferError("Can't decode packet ({})".format(body))
|
||||
|
||||
# TODO Check for both auth key ID and msg_key correctness
|
||||
reader.read_long() # remote_auth_key_id
|
||||
|
@ -204,14 +198,15 @@ class MtProtoSender:
|
|||
# msgs_ack, it may handle the request we wanted
|
||||
if code == 0x62d6b459:
|
||||
ack = reader.tgread_object()
|
||||
for r in self._pending_receive:
|
||||
if r.request_msg_id in ack.msg_ids:
|
||||
self._logger.debug('Ack found for a request')
|
||||
|
||||
if self.logging_out:
|
||||
self._logger.debug('Message ack confirmed a request')
|
||||
self._pending_receive.remove(r)
|
||||
r.confirm_received.set()
|
||||
# Ignore every ack request *unless* we are logging out, which is
# when it seems to be the only time it makes sense. We also need to
# set a non-None result since Telegram doesn't send a response for these.
|
||||
for msg_id in ack.msg_ids:
|
||||
r = self._pop_request_of_type(msg_id, LogOutRequest)
|
||||
if r:
|
||||
r.result = True # Telegram won't send this value
|
||||
r.confirm_received()
|
||||
self._logger.debug('Message ack confirmed', r)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -237,13 +232,26 @@ class MtProtoSender:
|
|||
|
||||
# region Message handling
|
||||
|
||||
def _pop_request(self, request_msg_id):
"""Pops a pending request from self._pending_receive, or
returns None if it's not found
def _pop_request(self, msg_id):
"""Pops a pending REQUEST from self._pending_receive, or
returns None if it's not found.
"""
for i in range(len(self._pending_receive)):
if self._pending_receive[i].request_msg_id == request_msg_id:
return self._pending_receive.pop(i)
message = self._pending_receive.pop(msg_id, None)
if message:
return message.request

def _pop_request_of_type(self, msg_id, t):
"""Pops a pending REQUEST from self._pending_receive if it matches
the given type, or returns None if it's not found/doesn't match.
"""
message = self._pending_receive.get(msg_id, None)
if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request

def _clear_all_pending(self):
for r in self._pending_receive.values():
r.confirm_received.set()
self._pending_receive.clear()

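A tiny illustration of why the dict-based _pending_receive above simplifies the bookkeeping: acknowledgements and RPC results can now pop a pending request by msg_id in constant time instead of scanning a list. The stand-in message class is an assumption for the sketch; the library uses TLMessage.

class FakeMessage:                      # stand-in for TLMessage (assumption)
    def __init__(self, msg_id, request):
        self.msg_id = msg_id
        self.request = request

pending = {}                            # msg_id -> message, as in the sender
m = FakeMessage(42, 'LogOutRequest')
pending[m.msg_id] = m

popped = pending.pop(42, None)          # what _pop_request does
assert popped.request == 'LogOutRequest'
assert pending.pop(42, None) is None    # popping twice is now harmless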
def _handle_pong(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling pong')
|
||||
|
@ -259,22 +267,17 @@ class MtProtoSender:
|
|||
|
||||
def _handle_container(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling container')
|
||||
reader.read_int(signed=False) # code
|
||||
size = reader.read_int()
|
||||
for _ in range(size):
|
||||
inner_msg_id = reader.read_long()
|
||||
reader.read_int() # inner_sequence
|
||||
inner_length = reader.read_int()
|
||||
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
||||
begin_position = reader.tell_position()
|
||||
|
||||
# Note that this code is IMPORTANT for skipping RPC results of
|
||||
# lost requests (i.e., ones from the previous connection session)
|
||||
try:
|
||||
if not self._process_msg(inner_msg_id, sequence, reader, state):
|
||||
reader.set_position(begin_position + inner_length)
|
||||
reader.set_position(begin_position + inner_len)
|
||||
except:
|
||||
# If any error is raised, something went wrong; skip the packet
|
||||
reader.set_position(begin_position + inner_length)
|
||||
reader.set_position(begin_position + inner_len)
|
||||
raise
|
||||
|
||||
return True
|
||||
|
@ -306,7 +309,6 @@ class MtProtoSender:
|
|||
# sent msg_id too low or too high (respectively).
|
||||
# Use the current msg_id to determine the right time offset.
|
||||
self.session.update_time_offset(correct_msg_id=msg_id)
|
||||
self.session.save()
|
||||
self._logger.debug('Read Bad Message error: ' + str(error))
|
||||
self._logger.debug('Attempting to use the correct time offset.')
|
||||
return True
|
||||
|
@ -334,7 +336,7 @@ class MtProtoSender:
|
|||
if self.session.report_errors and request:
|
||||
error = rpc_message_to_error(
|
||||
reader.read_int(), reader.tgread_string(),
|
||||
report_method=type(request).constructor_id
|
||||
report_method=type(request).CONSTRUCTOR_ID
|
||||
)
|
||||
else:
|
||||
error = rpc_message_to_error(
|
||||
|
@ -372,11 +374,7 @@ class MtProtoSender:
|
|||
|
||||
def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling gzip packed data')
|
||||
reader.read_int(signed=False) # code
|
||||
packed_data = reader.tgread_bytes()
|
||||
unpacked_data = gzip.decompress(packed_data)
|
||||
|
||||
with BinaryReader(unpacked_data) as compressed_reader:
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
return self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
import logging
|
||||
from datetime import timedelta
|
||||
import os
|
||||
import threading
|
||||
from datetime import timedelta, datetime
|
||||
from hashlib import md5
|
||||
from io import BytesIO
|
||||
from os import path
|
||||
from threading import Lock
|
||||
from time import sleep
|
||||
|
||||
from . import helpers as utils
|
||||
from .crypto import rsa, CdnDecrypter
|
||||
from .errors import (
|
||||
RPCError, BrokenAuthKeyError,
|
||||
FloodWaitError, FileMigrateError, TypeNotFoundError
|
||||
RPCError, BrokenAuthKeyError, ServerError,
|
||||
FloodWaitError, FileMigrateError, TypeNotFoundError,
|
||||
UnauthorizedError, PhoneMigrateError, NetworkMigrateError, UserMigrateError
|
||||
)
|
||||
from .network import authenticator, MtProtoSender, Connection, ConnectionMode
|
||||
from .tl import TLObject, Session
|
||||
from .tl.all_tlobjects import LAYER
|
||||
from .tl.functions import (
|
||||
InitConnectionRequest, InvokeWithLayerRequest
|
||||
InitConnectionRequest, InvokeWithLayerRequest, PingRequest
|
||||
)
|
||||
from .tl.functions.auth import (
|
||||
ImportAuthorizationRequest, ExportAuthorizationRequest
|
||||
|
@ -22,6 +26,7 @@ from .tl.functions.auth import (
|
|||
from .tl.functions.help import (
|
||||
GetCdnConfigRequest, GetConfigRequest
|
||||
)
|
||||
from .tl.functions.updates import GetStateRequest
|
||||
from .tl.functions.upload import (
|
||||
GetFileRequest, SaveBigFilePartRequest, SaveFilePartRequest
|
||||
)
|
||||
|
@ -52,7 +57,7 @@ class TelegramBareClient:
|
|||
"""
|
||||
|
||||
# Current TelegramClient version
|
||||
__version__ = '0.13.4'
|
||||
__version__ = '0.14.2'
|
||||
|
||||
# TODO Make this thread-safe, all connections share the same DC
|
||||
_dc_options = None
|
||||
|
@ -62,63 +67,124 @@ class TelegramBareClient:
|
|||
def __init__(self, session, api_id, api_hash,
|
||||
connection_mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None,
|
||||
process_updates=False,
|
||||
timeout=timedelta(seconds=5)):
|
||||
"""Initializes the Telegram client with the specified API ID and Hash.
|
||||
Session must always be a Session instance, and an optional proxy
|
||||
can also be specified to be used on the connection.
|
||||
"""
|
||||
update_workers=None,
|
||||
spawn_read_thread=False,
|
||||
timeout=timedelta(seconds=5),
|
||||
**kwargs):
|
||||
"""Refer to TelegramClient.__init__ for docs on this method"""
|
||||
if not api_id or not api_hash:
|
||||
raise PermissionError(
|
||||
"Your API ID or Hash cannot be empty or None. "
|
||||
"Refer to Telethon's README.rst for more information.")
|
||||
|
||||
# Determine what session object we have
|
||||
if isinstance(session, str) or session is None:
|
||||
session = Session.try_load_or_create_new(session)
|
||||
elif not isinstance(session, Session):
|
||||
raise ValueError(
|
||||
'The given session must be a str or a Session instance.'
|
||||
)
|
||||
|
||||
self.session = session
|
||||
self.api_id = int(api_id)
|
||||
self.api_hash = api_hash
|
||||
if self.api_id < 20: # official apps must use obfuscated
|
||||
self._connection_mode = ConnectionMode.TCP_OBFUSCATED
|
||||
else:
|
||||
self._connection_mode = connection_mode
|
||||
self.proxy = proxy
|
||||
self._timeout = timeout
|
||||
connection_mode = ConnectionMode.TCP_OBFUSCATED
|
||||
|
||||
# This is the main sender, which will be used from the thread
|
||||
# that calls .connect(). Every other thread will spawn a new
|
||||
# temporary connection. The connection on this one is always
|
||||
# kept open so Telegram can send us updates.
|
||||
self._sender = MtProtoSender(self.session, Connection(
|
||||
self.session.server_address, self.session.port,
|
||||
mode=connection_mode, proxy=proxy, timeout=timeout
|
||||
))
|
||||
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache "exported" senders 'dc_id: TelegramBareClient' and
|
||||
# their corresponding sessions not to recreate them all
|
||||
# the time since it's a (somewhat expensive) process.
|
||||
self._cached_clients = {}
|
||||
# Two threads may be calling reconnect() when the connection is lost,
|
||||
# we only want one to actually perform the reconnection.
|
||||
self._reconnect_lock = Lock()
|
||||
|
||||
# Cache "exported" sessions as 'dc_id: Session' not to recreate
|
||||
# them all the time since generating a new key is a relatively
|
||||
# expensive operation.
|
||||
self._exported_sessions = {}
|
||||
|
||||
# This member will process updates if enabled.
|
||||
# One may change self.updates.enabled at any later point.
|
||||
self.updates = UpdateState(process_updates)
|
||||
self.updates = UpdateState(workers=update_workers)
|
||||
|
||||
# These will be set later
|
||||
self._sender = None
|
||||
# Used on connection - the user may modify these and reconnect
|
||||
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
|
||||
for name, value in kwargs.items():
|
||||
if not hasattr(self.session, name):
|
||||
raise ValueError('Unknown named parameter', name)
|
||||
setattr(self.session, name, value)
|
||||
|
||||
# Despite the state of the real connection, keep track of whether
|
||||
# the user has explicitly called .connect() or .disconnect() here.
|
||||
# This information is required by the read thread, who will be the
|
||||
# one attempting to reconnect on the background *while* the user
|
||||
# doesn't explicitly call .disconnect(), thus telling it to stop
|
||||
# retrying. The main thread, knowing there is a background thread
|
||||
# attempting reconnection as soon as it happens, will just sleep.
|
||||
self._user_connected = False
|
||||
|
||||
# Save whether the user is authorized here (a.k.a. logged in)
|
||||
self._authorized = False
|
||||
|
||||
# Uploaded files cache so subsequent calls are instant
|
||||
self._upload_cache = {}
|
||||
|
||||
# Constantly read for results and updates from within the main client,
|
||||
# if the user has left enabled such option.
|
||||
self._spawn_read_thread = spawn_read_thread
|
||||
self._recv_thread = None
|
||||
|
||||
# Identifier of the main thread (the one that called .connect()).
|
||||
# This will be used to create new connections from any other thread,
|
||||
# so that requests can be sent in parallel.
|
||||
self._main_thread_ident = None
|
||||
|
||||
# Default PingRequest delay
|
||||
self._last_ping = datetime.now()
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Connecting
|
||||
|
||||
def connect(self, exported_auth=None):
|
||||
def connect(self, _exported_auth=None, _sync_updates=True, _cdn=False):
|
||||
"""Connects to the Telegram servers, executing authentication if
|
||||
required. Note that authenticating to the Telegram servers is
|
||||
not the same as authenticating the desired user itself, which
|
||||
may require a call (or several) to 'sign_in' for the first time.
|
||||
|
||||
If 'exported_auth' is not None, it will be used instead to
|
||||
determine the authorization key for the current session.
|
||||
"""
|
||||
if self.is_connected():
|
||||
return True
|
||||
Note that the optional parameters are meant for internal use.
|
||||
|
||||
connection = Connection(
|
||||
self.session.server_address, self.session.port,
|
||||
mode=self._connection_mode, proxy=self.proxy, timeout=self._timeout
|
||||
)
|
||||
If '_exported_auth' is not None, it will be used instead to
|
||||
determine the authorization key for the current session.
|
||||
|
||||
If '_sync_updates', sync_updates() will be called and a
|
||||
second thread will be started if necessary. Note that this
|
||||
will FAIL if the client is not connected to the user's
|
||||
native data center, raising a "UserMigrateError", and
|
||||
calling .disconnect() in the process.
|
||||
|
||||
If '_cdn' is False, methods that are not allowed on such data
|
||||
centers won't be invoked.
|
||||
"""
|
||||
self._main_thread_ident = threading.get_ident()
|
||||
|
||||
try:
|
||||
self._sender.connect()
|
||||
if not self.session.auth_key:
|
||||
# New key, we need to tell the server we're going to use
|
||||
# the latest layer
|
||||
try:
|
||||
self.session.auth_key, self.session.time_offset = \
|
||||
authenticator.do_authentication(connection)
|
||||
authenticator.do_authentication(self._sender.connection)
|
||||
except BrokenAuthKeyError:
|
||||
return False
|
||||
|
||||
|
@ -128,34 +194,47 @@ class TelegramBareClient:
|
|||
else:
|
||||
init_connection = self.session.layer != LAYER
|
||||
|
||||
self._sender = MtProtoSender(connection, self.session)
|
||||
self._sender.connect()
|
||||
|
||||
if init_connection:
|
||||
if exported_auth is not None:
|
||||
if _exported_auth is not None:
|
||||
self._init_connection(ImportAuthorizationRequest(
|
||||
exported_auth.id, exported_auth.bytes
|
||||
_exported_auth.id, _exported_auth.bytes
|
||||
))
|
||||
else:
|
||||
elif not _cdn:
|
||||
TelegramBareClient._dc_options = \
|
||||
self._init_connection(GetConfigRequest()).dc_options
|
||||
|
||||
elif exported_auth is not None:
|
||||
elif _exported_auth is not None:
|
||||
self(ImportAuthorizationRequest(
|
||||
exported_auth.id, exported_auth.bytes
|
||||
_exported_auth.id, _exported_auth.bytes
|
||||
))
|
||||
|
||||
if TelegramBareClient._dc_options is None:
|
||||
if TelegramBareClient._dc_options is None and not _cdn:
|
||||
TelegramBareClient._dc_options = \
|
||||
self(GetConfigRequest()).dc_options
|
||||
|
||||
# Connection was successful! Try syncing the update state
|
||||
# UNLESS '_sync_updates' is False (we probably are in
|
||||
# another data center and this would raise UserMigrateError)
|
||||
# to also assert whether the user is logged in or not.
|
||||
self._user_connected = True
|
||||
if _sync_updates and not _cdn:
|
||||
try:
|
||||
self.sync_updates()
|
||||
self._set_connected_and_authorized()
|
||||
except UnauthorizedError:
|
||||
self._authorized = False
|
||||
|
||||
return True
|
||||
|
||||
except TypeNotFoundError as e:
|
||||
# This is fine, probably layer migration
|
||||
self._logger.debug('Found invalid item, probably migrating', e)
|
||||
self.disconnect()
|
||||
return self.connect(exported_auth=exported_auth)
|
||||
return self.connect(
|
||||
_exported_auth=_exported_auth,
|
||||
_sync_updates=_sync_updates,
|
||||
_cdn=_cdn
|
||||
)
|
||||
|
||||
except (RPCError, ConnectionError) as error:
|
||||
# Probably errors from the previous session, ignore them
|
||||
|
@ -166,7 +245,7 @@ class TelegramBareClient:
|
|||
return False
|
||||
|
||||
def is_connected(self):
|
||||
return self._sender is not None and self._sender.is_connected()
|
||||
return self._sender.is_connected()
|
||||
|
||||
def _init_connection(self, query=None):
|
||||
result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
|
||||
|
@ -184,31 +263,54 @@ class TelegramBareClient:
|
|||
return result
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnects from the Telegram server"""
|
||||
if self._sender:
|
||||
self._sender.disconnect()
|
||||
self._sender = None
|
||||
"""Disconnects from the Telegram server
|
||||
and stops all the spawned threads"""
|
||||
self._user_connected = False
|
||||
self._recv_thread = None
|
||||
|
||||
def reconnect(self, new_dc=None):
|
||||
"""Disconnects and connects again (effectively reconnecting).
|
||||
# This will trigger a "ConnectionResetError", for subsequent calls
|
||||
# to read or send (from another thread) and usually, the background
|
||||
# thread would try restarting the connection but since the
|
||||
# ._recv_thread = None, it knows it doesn't have to.
|
||||
self._sender.disconnect()
|
||||
|
||||
If 'new_dc' is not None, the current authorization key is
|
||||
removed, the DC used is switched, and a new connection is made.
|
||||
# TODO Shall we clear the _exported_sessions, or may be reused?
|
||||
pass
|
||||
|
||||
def _reconnect(self, new_dc=None):
|
||||
"""If 'new_dc' is not set, only a call to .connect() will be made
|
||||
since it's assumed that the connection has been lost and the
|
||||
library is reconnecting.
|
||||
|
||||
If 'new_dc' is set, the client is first disconnected from the
|
||||
current data center, clears the auth key for the old DC, and
|
||||
connects to the new data center.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
if new_dc is not None:
|
||||
if new_dc is None:
|
||||
# Assume we are disconnected due to some error, so connect again
|
||||
with self._reconnect_lock:
|
||||
# Another thread may have connected again, so check that first
|
||||
if not self.is_connected():
|
||||
return self.connect()
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
self.disconnect()
|
||||
self.session.auth_key = None # Force creating new auth_key
|
||||
dc = self._get_dc(new_dc)
|
||||
self.session.server_address = dc.ip_address
|
||||
self.session.port = dc.port
|
||||
ip = dc.ip_address
|
||||
self._sender.connection.ip = self.session.server_address = ip
|
||||
self._sender.connection.port = self.session.port = dc.port
|
||||
self.session.save()
|
||||
|
||||
self.connect()
|
||||
return self.connect()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Working with different Data Centers
|
||||
# region Working with different connections/Data Centers
|
||||
|
||||
def _on_read_thread(self):
|
||||
return self._recv_thread is not None and \
|
||||
threading.get_ident() == self._recv_thread.ident
|
||||
|
||||
def _get_dc(self, dc_id, ipv6=False, cdn=False):
|
||||
"""Gets the Data Center (DC) associated to 'dc_id'"""
|
||||
|
@ -235,30 +337,23 @@ class TelegramBareClient:
|
|||
TelegramBareClient._dc_options = self(GetConfigRequest()).dc_options
|
||||
return self._get_dc(dc_id, ipv6=ipv6, cdn=cdn)
|
||||
|
||||
def _get_exported_client(self, dc_id,
|
||||
init_connection=False,
|
||||
bypass_cache=False):
|
||||
"""Gets a cached exported TelegramBareClient for the desired DC.
|
||||
def _get_exported_client(self, dc_id):
|
||||
"""Creates and connects a new TelegramBareClient for the desired DC.
|
||||
|
||||
If it's the first time retrieving the TelegramBareClient, the
|
||||
current authorization is exported to the new DC so that
|
||||
it can be used there, and the connection is initialized.
|
||||
|
||||
If after using the sender a ConnectionResetError is raised,
|
||||
this method should be called again with init_connection=True
|
||||
in order to perform the reconnection.
|
||||
|
||||
If bypass_cache is True, a new client will be exported and
|
||||
it will not be cached.
|
||||
If it's the first time calling the method with a given dc_id,
|
||||
a new session will be first created, and its auth key generated.
|
||||
Exporting/Importing the authorization will also be done so that
|
||||
the auth is bound with the key.
|
||||
"""
|
||||
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
|
||||
# for clearly showing how to export the authorization! ^^
|
||||
client = self._cached_clients.get(dc_id)
|
||||
if client and not bypass_cache:
|
||||
if init_connection:
|
||||
client.reconnect()
|
||||
return client
|
||||
session = self._exported_sessions.get(dc_id)
|
||||
if session:
|
||||
export_auth = None # Already bound with the auth key
|
||||
else:
|
||||
# TODO Add a lock, don't allow two threads to create an auth key
|
||||
# (when calling .connect() if there wasn't a previous session).
|
||||
# for the same data center.
|
||||
dc = self._get_dc(dc_id)
|
||||
|
||||
# Export the current authorization to the new DC.
|
||||
|
@ -272,80 +367,172 @@ class TelegramBareClient:
|
|||
session = Session(self.session)
|
||||
session.server_address = dc.ip_address
|
||||
session.port = dc.port
|
||||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
timeout=self._timeout
|
||||
)
|
||||
client.connect(exported_auth=export_auth)
|
||||
self._exported_sessions[dc_id] = session
|
||||
|
||||
if not bypass_cache:
|
||||
# Don't go through this expensive process every time.
|
||||
self._cached_clients[dc_id] = client
|
||||
return client
|
||||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
)
|
||||
client.connect(_exported_auth=export_auth, _sync_updates=False)
|
||||
client._authorized = True # We exported the auth, so we got auth
|
||||
return client
|
||||
|
||||
def _get_cdn_client(self, cdn_redirect):
|
||||
"""Similar to ._get_exported_client, but for CDNs"""
|
||||
session = self._exported_sessions.get(cdn_redirect.dc_id)
|
||||
if not session:
|
||||
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
|
||||
session = Session(self.session)
|
||||
session.server_address = dc.ip_address
|
||||
session.port = dc.port
|
||||
self._exported_sessions[cdn_redirect.dc_id] = session
|
||||
|
||||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
)
|
||||
|
||||
# This will make use of the new RSA keys for this specific CDN.
|
||||
#
|
||||
# This relies on the fact that TelegramBareClient._dc_options is
|
||||
# static and it won't be called from this DC (it would fail).
|
||||
client.connect(_cdn=True) # Avoid invoking non-CDN specific methods
|
||||
client._authorized = self._authorized
|
||||
return client
|
||||
|
||||
# endregion
|
||||
|
||||
# region Invoking Telegram requests
|
||||
|
||||
def invoke(self, request, call_receive=True, retries=5):
|
||||
def invoke(self, *requests, retries=5):
|
||||
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
|
||||
|
||||
If 'updates' is not None, all read update object will be put
|
||||
in such list. Otherwise, update objects will be ignored.
|
||||
|
||||
If 'call_receive' is set to False, then there should be another
|
||||
thread calling to 'self._sender.receive()' running or this method
|
||||
will lock forever.
|
||||
The invoke will be retried up to 'retries' times before raising
|
||||
ValueError().
|
||||
"""
|
||||
if not isinstance(request, TLObject) and not request.content_related:
|
||||
if not all(isinstance(x, TLObject) and
|
||||
x.content_related for x in requests):
|
||||
raise ValueError('You can only invoke requests, not types!')
|
||||
|
||||
if not self._sender:
|
||||
raise ValueError('You must be connected to invoke requests!')
|
||||
# Determine the sender to be used (main or a new connection)
|
||||
on_main_thread = threading.get_ident() == self._main_thread_ident
|
||||
if on_main_thread or self._on_read_thread():
|
||||
sender = self._sender
|
||||
else:
|
||||
sender = self._sender.clone()
|
||||
sender.connect()
|
||||
|
||||
# We should call receive from this thread if there's no background
|
||||
# thread reading or if the server disconnected us and we're trying
|
||||
# to reconnect. This is because the read thread may either be
|
||||
# locked also trying to reconnect or we may be said thread already.
|
||||
call_receive = not on_main_thread or self._recv_thread is None \
|
||||
or self._reconnect_lock.locked()
|
||||
try:
|
||||
for _ in range(retries):
|
||||
result = self._invoke(sender, call_receive, *requests)
|
||||
if result:
|
||||
return result
|
||||
|
||||
if retries <= 0:
|
||||
raise ValueError('Number of retries reached 0.')
|
||||
finally:
|
||||
if sender != self._sender:
|
||||
sender.disconnect() # Close temporary connections
|
||||
|
||||
def _invoke(self, sender, call_receive, *requests):
|
||||
try:
|
||||
# Ensure that we start with no previous errors (i.e. resending)
|
||||
request.confirm_received.clear()
|
||||
request.rpc_error = None
|
||||
for x in requests:
|
||||
x.confirm_received.clear()
|
||||
x.rpc_error = None
|
||||
|
||||
sender.send(*requests)
|
||||
|
||||
self._sender.send(request)
|
||||
if not call_receive:
|
||||
# TODO This will be slightly troublesome if we allow
|
||||
# switching between constant read or not on the fly.
|
||||
# Must also watch out for calling .read() from two places,
|
||||
# in which case a Lock would be required for .receive().
|
||||
request.confirm_received.wait() # TODO Socket's timeout here?
|
||||
for x in requests:
|
||||
x.confirm_received.wait(
|
||||
sender.connection.get_timeout()
|
||||
)
|
||||
else:
|
||||
while not request.confirm_received.is_set():
|
||||
self._sender.receive(update_state=self.updates)
|
||||
while not all(x.confirm_received.is_set() for x in requests):
|
||||
sender.receive(update_state=self.updates)
|
||||
|
||||
except TimeoutError:
|
||||
pass # We will just retry
|
||||
|
||||
except ConnectionResetError:
|
||||
if not self._authorized or self._reconnect_lock.locked():
|
||||
# Only attempt reconnecting if we're authorized and not
|
||||
# reconnecting already.
|
||||
raise
|
||||
|
||||
self._logger.debug('Server disconnected us. Reconnecting and '
|
||||
'resending request...')
|
||||
self.reconnect()
|
||||
|
||||
if sender != self._sender:
|
||||
# TODO Try reconnecting forever too?
|
||||
sender.connect()
|
||||
else:
|
||||
while self._user_connected and not self._reconnect():
|
||||
sleep(0.1) # Retry forever until we can send the request
|
||||
|
||||
finally:
|
||||
if sender != self._sender:
|
||||
sender.disconnect()
|
||||
|
||||
try:
|
||||
raise next(x.rpc_error for x in requests if x.rpc_error)
|
||||
except StopIteration:
|
||||
if any(x.result is None for x in requests):
|
||||
# "A container may only be accepted or
|
||||
# rejected by the other party as a whole."
|
||||
return None
|
||||
elif len(requests) == 1:
|
||||
return requests[0].result
|
||||
else:
|
||||
return [x.result for x in requests]
|
||||
|
||||
except (PhoneMigrateError, NetworkMigrateError,
|
||||
UserMigrateError) as e:
|
||||
self._logger.debug(
|
||||
'DC error when invoking request, '
|
||||
'attempting to reconnect at DC {}'.format(e.new_dc)
|
||||
)
|
||||
|
||||
# TODO What happens with the background thread here?
|
||||
# For normal use cases, this won't happen, because this will only
|
||||
# be on the very first connection (not authorized, not running),
|
||||
# but may be an issue for people who actually travel?
|
||||
self._reconnect(new_dc=e.new_dc)
|
||||
return self._invoke(sender, call_receive, *requests)
|
||||
|
||||
except ServerError as e:
|
||||
# Telegram is having some issues, just retry
|
||||
self._logger.debug(
|
||||
'[ERROR] Telegram is having some internal issues', e
|
||||
)
|
||||
|
||||
except FloodWaitError:
|
||||
sender.disconnect()
|
||||
self.disconnect()
|
||||
raise
|
||||
|
||||
if request.rpc_error:
|
||||
raise request.rpc_error
|
||||
if request.result is None:
|
||||
return self.invoke(
|
||||
request, call_receive=call_receive, retries=(retries - 1)
|
||||
)
|
||||
else:
|
||||
return request.result
|
||||
|
||||
# Let people use client(SomeRequest()) instead of client.invoke(...)
__call__ = invoke
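A hedged usage sketch of the calling convention above. The api_id/api_hash values are placeholders and exporting TelegramBareClient from the top-level package is assumed; since __call__ aliases invoke, requests are sent by calling the client, and several content-related requests can be passed at once to travel in a single container.

from telethon import TelegramBareClient            # import path assumed
from telethon.tl.functions.help import GetConfigRequest
from telethon.tl.functions.updates import GetStateRequest

client = TelegramBareClient('session', 12345, '0123456789abcdef')  # placeholders
if client.connect():
    config = client(GetConfigRequest())            # same as client.invoke(...)
    # Batching several requests in one container (GetState needs a signed-in session)
    config, state = client(GetConfigRequest(), GetStateRequest())
    client.disconnect()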
|
||||
|
||||
# Some really basic functionality
|
||||
|
||||
def is_user_authorized(self):
|
||||
"""Has the user been authorized yet
|
||||
(code request sent and confirmed)?"""
|
||||
return self._authorized
|
||||
|
||||
# endregion
|
||||
|
||||
# region Uploading media
|
||||
|
@ -371,10 +558,10 @@ class TelegramBareClient:
|
|||
|
||||
Default values for the optional parameters if left as None are:
|
||||
part_size_kb = get_appropriated_part_size(file_size)
|
||||
file_name = path.basename(file_path)
|
||||
file_name = os.path.basename(file_path)
|
||||
"""
|
||||
if isinstance(file, str):
|
||||
file_size = path.getsize(file)
|
||||
file_size = os.path.getsize(file)
|
||||
elif isinstance(file, bytes):
|
||||
file_size = len(file)
|
||||
else:
|
||||
|
@ -430,7 +617,7 @@ class TelegramBareClient:
|
|||
# Set a default file name if None was specified
|
||||
if not file_name:
|
||||
if isinstance(file, str):
|
||||
file_name = path.basename(file)
|
||||
file_name = os.path.basename(file)
|
||||
else:
|
||||
file_name = str(file_id)
|
||||
|
||||
|
@ -499,7 +686,7 @@ class TelegramBareClient:
|
|||
if isinstance(result, FileCdnRedirect):
|
||||
cdn_decrypter, result = \
|
||||
CdnDecrypter.prepare_decrypter(
|
||||
client, TelegramBareClient, result
|
||||
client, self._get_cdn_client(result), result
|
||||
)
|
||||
|
||||
except FileMigrateError as e:
|
||||
|
@ -518,6 +705,9 @@ class TelegramBareClient:
|
|||
if progress_callback:
|
||||
progress_callback(f.tell(), file_size)
|
||||
finally:
|
||||
if client != self:
|
||||
client.disconnect()
|
||||
|
||||
if cdn_decrypter:
|
||||
try:
|
||||
cdn_decrypter.client.disconnect()
|
||||
|
@ -527,3 +717,80 @@ class TelegramBareClient:
|
|||
f.close()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Updates handling
|
||||
|
||||
def sync_updates(self):
|
||||
"""Synchronizes self.updates to their initial state. Will be
|
||||
called automatically on connection if self.updates.enabled = True,
|
||||
otherwise it should be called manually after enabling updates.
|
||||
"""
|
||||
self.updates.process(self(GetStateRequest()))
|
||||
|
||||
def add_update_handler(self, handler):
|
||||
"""Adds an update handler (a function which takes a TLObject,
|
||||
an update, as its parameter) and listens for updates"""
|
||||
sync = not self.updates.handlers
|
||||
self.updates.handlers.append(handler)
|
||||
if sync:
|
||||
self.sync_updates()
|
||||
|
||||
def remove_update_handler(self, handler):
|
||||
self.updates.handlers.remove(handler)
|
||||
|
||||
def list_update_handlers(self):
|
||||
return self.updates.handlers[:]
|
||||
|
||||
# endregion
|
||||
|
||||
# Constant read
|
||||
|
||||
def _set_connected_and_authorized(self):
|
||||
self._authorized = True
|
||||
if self._spawn_read_thread and self._recv_thread is None:
|
||||
self._recv_thread = threading.Thread(
|
||||
name='ReadThread', daemon=True,
|
||||
target=self._recv_thread_impl
|
||||
)
|
||||
self._recv_thread.start()
|
||||
|
||||
# By using this approach, another thread will be
|
||||
# created and started upon connection to constantly read
|
||||
# from the other end. Otherwise, manual calls to .receive()
|
||||
# must be performed. The MtProtoSender cannot be connected,
|
||||
# or an error will be thrown.
|
||||
#
|
||||
# This way, sending and receiving will be completely independent.
|
||||
def _recv_thread_impl(self):
|
||||
while self._user_connected:
|
||||
try:
|
||||
if datetime.now() > self._last_ping + self._ping_delay:
|
||||
self._sender.send(PingRequest(
|
||||
int.from_bytes(os.urandom(8), 'big', signed=True)
|
||||
))
|
||||
self._last_ping = datetime.now()
|
||||
|
||||
self._sender.receive(update_state=self.updates)
|
||||
except TimeoutError:
|
||||
# No problem.
|
||||
pass
|
||||
except ConnectionResetError:
|
||||
self._logger.debug('Server disconnected us. Reconnecting...')
|
||||
while self._user_connected and not self._reconnect():
|
||||
sleep(0.1) # Retry forever, this is instant messaging
|
||||
|
||||
except Exception as error:
|
||||
# Unknown exception, pass it to the main thread
|
||||
self._logger.debug(
|
||||
'[ERROR] Unknown error on the read thread, please report',
|
||||
error
|
||||
)
|
||||
# If something strange happens we don't want to enter an
|
||||
# infinite loop where all we do is raise an exception, so
|
||||
# add a little sleep to avoid the CPU usage going mad.
|
||||
sleep(0.1)
|
||||
break
|
||||
|
||||
self._recv_thread = None
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
import os
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache
|
||||
from mimetypes import guess_type
|
||||
from threading import Thread
|
||||
|
||||
try:
|
||||
import socks
|
||||
except ImportError:
|
||||
socks = None
|
||||
|
||||
from . import TelegramBareClient
|
||||
from . import helpers as utils
|
||||
from .errors import (
|
||||
RPCError, UnauthorizedError, InvalidParameterError, PhoneCodeEmptyError,
|
||||
PhoneMigrateError, NetworkMigrateError, UserMigrateError,
|
||||
PhoneCodeExpiredError, PhoneCodeHashEmptyError, PhoneCodeInvalidError
|
||||
)
|
||||
from .network import ConnectionMode
|
||||
from .tl import Session, TLObject
|
||||
from .tl.functions import PingRequest
|
||||
from .tl import TLObject
|
||||
from .tl.functions.account import (
|
||||
GetPasswordRequest
|
||||
)
|
||||
|
@ -29,9 +30,6 @@ from .tl.functions.messages import (
|
|||
GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest,
|
||||
SendMessageRequest
|
||||
)
|
||||
from .tl.functions.updates import (
|
||||
GetStateRequest
|
||||
)
|
||||
from .tl.functions.users import (
|
||||
GetUsersRequest
|
||||
)
|
||||
|
@ -43,6 +41,7 @@ from .tl.types import (
|
|||
InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID,
|
||||
UpdateNewMessage, UpdateShortSentMessage
|
||||
)
|
||||
from .tl.types.messages import DialogsSlice
|
||||
from .utils import find_user_or_chat, get_extension
|
||||
|
||||
|
||||
|
@ -59,8 +58,9 @@ class TelegramClient(TelegramBareClient):
|
|||
def __init__(self, session, api_id, api_hash,
|
||||
connection_mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None,
|
||||
process_updates=False,
|
||||
update_workers=None,
|
||||
timeout=timedelta(seconds=5),
|
||||
spawn_read_thread=True,
|
||||
**kwargs):
|
||||
"""Initializes the Telegram client with the specified API ID and Hash.
|
||||
|
||||
|
@ -73,15 +73,21 @@ class TelegramClient(TelegramBareClient):
|
|||
This will only affect how messages are sent over the network
|
||||
and how much processing is required before sending them.
|
||||
|
||||
If 'process_updates' is set to True, incoming updates will be
|
||||
processed and you must manually call 'self.updates.poll()' from
|
||||
another thread to retrieve the saved update objects, or your
|
||||
memory will fill with these. You may modify the value of
|
||||
'self.updates.polling' at any later point.
|
||||
The 'update_workers' parameter changes behaviour depending on its value:
is None: Updates will *not* be stored in memory.
= 0: Another thread is responsible for calling self.updates.poll()
> 0: 'update_workers' background threads will be spawned, and any
of them will invoke all the self.updates.handlers.
|
||||
|
||||
Despite the value of 'process_updates', if you later call
|
||||
'.add_update_handler(...)', updates will also be processed
|
||||
and the update objects will be passed to the handlers you added.
|
||||
If 'spawn_read_thread', a background thread will be started once
|
||||
an authorized user has been logged in to Telegram to read items
|
||||
(such as updates and responses) from the network as soon as they
|
||||
occur, which will speed things up.
|
||||
|
||||
If you don't want to spawn any additional threads, pending updates
|
||||
will be read and processed accordingly after invoking a request
|
||||
and not immediately. This is useful if you don't care about updates
|
||||
at all and have set 'update_workers=None'.
|
||||
|
||||
If more named arguments are provided as **kwargs, they will be
|
||||
used to update the Session instance. Most common settings are:
|
||||
|
@ -92,210 +98,55 @@ class TelegramClient(TelegramBareClient):
|
|||
system_lang_code = lang_code
|
||||
report_errors = True
|
||||
"""
|
||||
if not api_id or not api_hash:
|
||||
raise PermissionError(
|
||||
"Your API ID or Hash cannot be empty or None. "
|
||||
"Refer to Telethon's README.rst for more information.")
|
||||
|
||||
# Determine what session object we have
|
||||
if isinstance(session, str) or session is None:
|
||||
session = Session.try_load_or_create_new(session)
|
||||
elif not isinstance(session, Session):
|
||||
raise ValueError(
|
||||
'The given session must be a str or a Session instance.')
|
||||
|
||||
super().__init__(
|
||||
session, api_id, api_hash,
|
||||
connection_mode=connection_mode,
|
||||
proxy=proxy,
|
||||
process_updates=process_updates,
|
||||
timeout=timeout
|
||||
update_workers=update_workers,
|
||||
spawn_read_thread=spawn_read_thread,
|
||||
timeout=timeout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Used on connection - the user may modify these and reconnect
|
||||
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
|
||||
for name, value in kwargs.items():
|
||||
if hasattr(self.session, name):
|
||||
setattr(self.session, name, value)
|
||||
|
||||
self._updates_thread = None
|
||||
# Some fields to easy signing in
|
||||
self._phone_code_hash = None
|
||||
self._phone = None
|
||||
|
||||
# Uploaded files cache so subsequent calls are instant
|
||||
self._upload_cache = {}
|
||||
|
||||
# Constantly read for results and updates from within the main client
|
||||
self._recv_thread = None
|
||||
|
||||
# Default PingRequest delay
|
||||
self._last_ping = datetime.now()
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
|
||||
# endregion
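A hedged usage sketch of the constructor parameters documented above (credentials are placeholders and the session is assumed to be already authorized; this mirrors the docstring rather than a tested snippet): spawn a couple of update workers and register a handler so updates are dispatched in the background.

from telethon import TelegramClient

def on_update(update):
    print('Received update:', type(update).__name__)

client = TelegramClient('session', 12345, '0123456789abcdef',   # placeholders
                        update_workers=2,          # > 0: worker threads run the handlers
                        spawn_read_thread=True)    # read from the network in background
client.connect()
client.add_update_handler(on_update)               # first handler triggers sync_updates()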
|
||||
|
||||
# region Connecting
|
||||
|
||||
def connect(self, exported_auth=None):
|
||||
"""Connects to the Telegram servers, executing authentication if
|
||||
required. Note that authenticating to the Telegram servers is
|
||||
not the same as authenticating the desired user itself, which
|
||||
may require a call (or several) to 'sign_in' for the first time.
|
||||
|
||||
exported_auth is meant for internal purposes and can be ignored.
|
||||
"""
|
||||
if self._sender and self._sender.is_connected():
|
||||
return
|
||||
|
||||
ok = super().connect(exported_auth=exported_auth)
|
||||
# The main TelegramClient is the only one that will have
|
||||
# constant_read, since it's also the only one who receives
|
||||
# updates and need to be processed as soon as they occur.
|
||||
#
|
||||
# TODO Allow to disable this to avoid the creation of a new thread
|
||||
# if the user is not going to work with updates at all? Whether to
|
||||
# read constantly or not for updates needs to be known before hand,
|
||||
# and further updates won't be able to be added unless allowing to
|
||||
# switch the mode on the fly.
|
||||
if ok:
|
||||
self._recv_thread = Thread(
|
||||
name='ReadThread', daemon=True,
|
||||
target=self._recv_thread_impl
|
||||
)
|
||||
self._recv_thread.start()
|
||||
if self.updates.polling:
|
||||
self.sync_updates()
|
||||
|
||||
return ok
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnects from the Telegram server
|
||||
and stops all the spawned threads"""
|
||||
if not self._sender or not self._sender.is_connected():
|
||||
return
|
||||
|
||||
# The existing thread will close eventually, since it's
|
||||
# only running while the MtProtoSender.is_connected()
|
||||
self._recv_thread = None
|
||||
|
||||
# This will trigger a "ConnectionResetError", usually, the background
|
||||
# thread would try restarting the connection but since the
|
||||
# ._recv_thread = None, it knows it doesn't have to.
|
||||
super().disconnect()
|
||||
|
||||
# Also disconnect all the cached senders
|
||||
for sender in self._cached_clients.values():
|
||||
sender.disconnect()
|
||||
|
||||
self._cached_clients.clear()
|
||||
|
||||
# endregion
|
||||
|
||||
# region Working with different connections
|
||||
|
||||
def create_new_connection(self, on_dc=None, timeout=timedelta(seconds=5)):
|
||||
"""Creates a new connection which can be used in parallel
|
||||
with the original TelegramClient. A TelegramBareClient
|
||||
will be returned already connected, and the caller is
|
||||
responsible to disconnect it.
|
||||
|
||||
If 'on_dc' is None, the new client will run on the same
|
||||
data center as the current client (most common case).
|
||||
|
||||
If the client is meant to be used on a different data
|
||||
center, the data center ID should be specified instead.
|
||||
"""
|
||||
if on_dc is None:
|
||||
client = TelegramBareClient(
|
||||
self.session, self.api_id, self.api_hash,
|
||||
proxy=self.proxy, timeout=timeout
|
||||
)
|
||||
client.connect()
|
||||
else:
|
||||
client = self._get_exported_client(on_dc, bypass_cache=True)
|
||||
|
||||
return client
|
||||
|
||||
# endregion
|
||||
|
||||
# region Telegram requests functions
|
||||
|
||||
def invoke(self, request, *args, **kwargs):
|
||||
"""Invokes (sends) a MTProtoRequest and returns (receives) its result.
|
||||
An optional 'retries' parameter can be set.
|
||||
|
||||
*args will be ignored.
|
||||
"""
|
||||
if self._recv_thread is not None and \
|
||||
threading.get_ident() == self._recv_thread.ident:
|
||||
raise AssertionError('Cannot invoke requests from the ReadThread')
|
||||
|
||||
self.updates.check_error()
|
||||
|
||||
try:
|
||||
# Users may call this method from within some update handler.
|
||||
# If this is the case, then the thread invoking the request
|
||||
# will be the one which should be reading (but is invoking the
|
||||
# request) thus not being available to read it "in the background"
|
||||
# and it's needed to call receive.
|
||||
return super().invoke(
|
||||
request, call_receive=self._recv_thread is None,
|
||||
retries=kwargs.get('retries', 5)
|
||||
)
|
||||
|
||||
except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e:
|
||||
self._logger.debug('DC error when invoking request, '
|
||||
'attempting to reconnect at DC {}'
|
||||
.format(e.new_dc))
|
||||
|
||||
self.reconnect(new_dc=e.new_dc)
|
||||
return self.invoke(request)
|
||||
|
||||
# Let people use client(SomeRequest()) instead client.invoke(...)
|
||||
__call__ = invoke
|
||||
|
||||
def invoke_on_dc(self, request, dc_id, reconnect=False):
|
||||
"""Invokes the given request on a different DC
|
||||
by making use of the exported MtProtoSenders.
|
||||
|
||||
If 'reconnect=True', then the a reconnection will be performed and
|
||||
ConnectionResetError will be raised if it occurs a second time.
|
||||
"""
|
||||
try:
|
||||
client = self._get_exported_client(
|
||||
dc_id, init_connection=reconnect)
|
||||
|
||||
return client.invoke(request)
|
||||
|
||||
except ConnectionResetError:
|
||||
if reconnect:
|
||||
raise
|
||||
else:
|
||||
return self.invoke_on_dc(request, dc_id, reconnect=True)
|
||||
|
||||
# region Authorization requests
|
||||
|
||||
def is_user_authorized(self):
|
||||
"""Has the user been authorized yet
|
||||
(code request sent and confirmed)?"""
|
||||
return self.session and self.get_me() is not None
|
||||
|
||||
def send_code_request(self, phone):
|
||||
"""Sends a code request to the specified phone number"""
|
||||
result = self(
|
||||
SendCodeRequest(phone, self.api_id, self.api_hash))
|
||||
if isinstance(phone, int):
|
||||
phone = str(phone)
|
||||
elif phone.startswith('+'):
|
||||
phone = phone.strip('+')
|
||||
|
||||
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
|
||||
self._phone = phone
|
||||
self._phone_code_hash = result.phone_code_hash
|
||||
return result
|
||||
|
||||
def sign_in(self, phone=None, code=None,
|
||||
password=None, bot_token=None):
|
||||
password=None, bot_token=None, phone_code_hash=None):
|
||||
"""Completes the sign in process with the phone number + code pair.
|
||||
|
||||
If no phone or code is provided, then the sole password will be used.
|
||||
The password should be used after a normal authorization attempt
|
||||
has happened and an SessionPasswordNeededError was raised.
|
||||
|
||||
If you're calling .sign_in() on two completely different clients
(for example, through an API that creates a new client per phone),
you must first call .sign_in(phone) to send the code, and then,
using the result that call returns, call
.sign_in(phone, code, phone_code_hash=result.phone_code_hash).
|
||||
|
||||
If this is done on the same client, the client will fill said values
|
||||
for you.
|
||||
|
||||
To login as a bot, only `bot_token` should be provided.
|
||||
This should equal to the bot access hash provided by
|
||||
https://t.me/BotFather during your bot creation.
|
||||
|
@ -306,64 +157,66 @@ class TelegramClient(TelegramBareClient):
|
|||
if phone and not code:
|
||||
return self.send_code_request(phone)
|
||||
elif code:
|
||||
if self._phone is None:
|
||||
phone = phone or self._phone
|
||||
phone_code_hash = phone_code_hash or self._phone_code_hash
|
||||
if not phone:
|
||||
raise ValueError(
|
||||
'Please make sure to call send_code_request first.')
|
||||
'Please make sure to call send_code_request first.'
|
||||
)
|
||||
if not phone_code_hash:
|
||||
raise ValueError('You also need to provide a phone_code_hash.')
|
||||
|
||||
try:
|
||||
if isinstance(code, int):
|
||||
code = str(code)
|
||||
result = self(SignInRequest(
|
||||
self._phone, self._phone_code_hash, code
|
||||
phone, phone_code_hash, code
|
||||
))
|
||||
|
||||
except (PhoneCodeEmptyError, PhoneCodeExpiredError,
|
||||
PhoneCodeHashEmptyError, PhoneCodeInvalidError):
|
||||
return None
|
||||
|
||||
elif password:
|
||||
salt = self(GetPasswordRequest()).current_salt
|
||||
result = self(
|
||||
CheckPasswordRequest(utils.get_password_hash(password, salt)))
|
||||
|
||||
result = self(CheckPasswordRequest(
|
||||
utils.get_password_hash(password, salt)
|
||||
))
|
||||
elif bot_token:
|
||||
result = self(ImportBotAuthorizationRequest(
|
||||
flags=0, bot_auth_token=bot_token,
|
||||
api_id=self.api_id, api_hash=self.api_hash))
|
||||
|
||||
api_id=self.api_id, api_hash=self.api_hash
|
||||
))
|
||||
else:
|
||||
raise ValueError(
|
||||
'You must provide a phone and a code the first time, '
|
||||
'and a password only if an RPCError was raised before.')
|
||||
'and a password only if an RPCError was raised before.'
|
||||
)
|
||||
|
||||
self._set_connected_and_authorized()
|
||||
return result.user
|
||||
|
||||
def sign_up(self, code, first_name, last_name=''):
|
||||
"""Signs up to Telegram. Make sure you sent a code request first!"""
|
||||
return self(SignUpRequest(
|
||||
result = self(SignUpRequest(
|
||||
phone_number=self._phone,
|
||||
phone_code_hash=self._phone_code_hash,
|
||||
phone_code=code,
|
||||
first_name=first_name,
|
||||
last_name=last_name
|
||||
)).user
|
||||
))
|
||||
|
||||
self._set_connected_and_authorized()
|
||||
return result.user
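A hedged sketch of the two-step sign-in flow described in the docstring above, including the cross-client case where phone_code_hash must be carried over manually. The phone number is a placeholder and client, client_a and client_b are assumed to be already connected TelegramClient instances.

# Same-client flow: the client remembers phone and phone_code_hash for you.
sent = client.send_code_request('+10000000000')
me = client.sign_in('+10000000000', input('Code: '))

# Two-client flow (e.g. a web API creating one client per request):
# the second client needs the phone_code_hash returned by the first call.
sent = client_a.sign_in(phone='+10000000000')        # sends the code
me = client_b.sign_in(phone='+10000000000',
                      code=input('Code: '),
                      phone_code_hash=sent.phone_code_hash)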
|
||||
|
||||
def log_out(self):
|
||||
"""Logs out and deletes the current session.
|
||||
Returns True if everything went okay."""
|
||||
# Special flag when logging out (so the ack request confirms it)
|
||||
self._sender.logging_out = True
|
||||
|
||||
try:
|
||||
self(LogOutRequest())
|
||||
# The server may have already disconnected us, we still
|
||||
# try to disconnect to make sure.
|
||||
self.disconnect()
|
||||
except (RPCError, ConnectionError):
|
||||
# Something happened when logging out, restore the state back
|
||||
self._sender.logging_out = False
|
||||
except RPCError:
|
||||
return False
|
||||
|
||||
self.disconnect()
|
||||
self.session.delete()
|
||||
self.session = None
|
||||
return True
|
||||
|
@ -386,22 +239,61 @@ class TelegramClient(TelegramBareClient):
|
|||
offset_id=0,
|
||||
offset_peer=InputPeerEmpty()):
|
||||
"""Returns a tuple of lists ([dialogs], [entities])
|
||||
with at least 'limit' items each.
|
||||
with at least 'limit' items each unless all dialogs were consumed.
|
||||
|
||||
If `limit` is None, all dialogs will be retrieved (from the given
|
||||
offset) will be retrieved.
|
||||
|
||||
If `limit` is 0, all dialogs will (should) retrieved.
|
||||
The `entities` represent the user, chat or channel
|
||||
corresponding to that dialog.
|
||||
corresponding to that dialog. If it's an integer, not
|
||||
all dialogs may be retrieved at once.
|
||||
"""
|
||||
if limit is None:
|
||||
limit = float('inf')
|
||||
|
||||
r = self(
|
||||
GetDialogsRequest(
|
||||
dialogs = {} # Use Dialog.top_message as identifier to avoid dupes
|
||||
messages = {} # Used later for sorting TODO also return these?
|
||||
entities = {}
|
||||
while len(dialogs) < limit:
|
||||
r = self(GetDialogsRequest(
|
||||
offset_date=offset_date,
|
||||
offset_id=offset_id,
|
||||
offset_peer=offset_peer,
|
||||
limit=limit))
|
||||
limit=0 # limit 0 often means "as much as possible"
|
||||
))
|
||||
if not r.dialogs:
|
||||
break
|
||||
|
||||
for d in r.dialogs:
|
||||
dialogs[d.top_message] = d
|
||||
for m in r.messages:
|
||||
messages[m.id] = m
|
||||
|
||||
# We assume users can't have the same ID as a chat
|
||||
for u in r.users:
|
||||
entities[u.id] = u
|
||||
for c in r.chats:
|
||||
entities[c.id] = c
|
||||
|
||||
if isinstance(r, DialogsSlice):
|
||||
# Don't enter next iteration if we already got all
|
||||
break
|
||||
|
||||
offset_date = r.messages[-1].date
|
||||
offset_peer = find_user_or_chat(r.dialogs[-1].peer, entities,
|
||||
entities)
|
||||
offset_id = r.messages[-1].id & 4294967296 # Telegram/danog magic
|
||||
|
||||
# Sort by message date
|
||||
no_date = datetime.fromtimestamp(0)
|
||||
dialogs = sorted(
|
||||
list(dialogs.values()),
|
||||
key=lambda d: getattr(messages[d.top_message], 'date', no_date)
|
||||
)
|
||||
return (
|
||||
r.dialogs,
|
||||
[find_user_or_chat(d.peer, r.users, r.chats) for d in r.dialogs])
|
||||
dialogs,
|
||||
[find_user_or_chat(d.peer, entities, entities) for d in dialogs]
|
||||
)
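A hedged usage sketch of the paginated get_dialogs above, assuming an already connected and authorized TelegramClient named client; attribute names on the returned entities follow the TL schema.

# Fetch roughly the 30 most recent dialogs together with their entities.
dialogs, entities = client.get_dialogs(limit=30)
for dialog, who in zip(dialogs, entities):
    name = getattr(who, 'title', None) or getattr(who, 'first_name', '?')
    print(dialog.unread_count, name)

# limit=None keeps requesting pages until Telegram has no more dialogs to give.
all_dialogs, all_entities = client.get_dialogs(limit=None)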
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -427,7 +319,7 @@ class TelegramClient(TelegramBareClient):
|
|||
reply_to_msg_id=self._get_reply_to(reply_to)
|
||||
)
|
||||
result = self(request)
|
||||
if isinstance(request, UpdateShortSentMessage):
|
||||
if isinstance(result, UpdateShortSentMessage):
|
||||
return Message(
|
||||
id=result.id,
|
||||
to_id=entity,
|
||||
|
@ -540,7 +432,7 @@ class TelegramClient(TelegramBareClient):
|
|||
return reply_to
|
||||
|
||||
if isinstance(reply_to, TLObject) and \
|
||||
type(reply_to).subclass_of_id == 0x790009e3:
|
||||
type(reply_to).SUBCLASS_OF_ID == 0x790009e3:
|
||||
# hex(crc32(b'Message')) = 0x790009e3
|
||||
return reply_to.id
|
||||
|
||||
|
@ -972,77 +864,3 @@ class TelegramClient(TelegramBareClient):
|
|||
)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Updates handling
|
||||
|
||||
def sync_updates(self):
|
||||
"""Synchronizes self.updates to their initial state. Will be
|
||||
called automatically on connection if self.updates.enabled = True,
|
||||
otherwise it should be called manually after enabling updates.
|
||||
"""
|
||||
try:
|
||||
self.updates.process(self(GetStateRequest()))
|
||||
return True
|
||||
except UnauthorizedError:
|
||||
return False
|
||||
|
||||
def add_update_handler(self, handler):
|
||||
"""Adds an update handler (a function which takes a TLObject,
|
||||
an update, as its parameter) and listens for updates"""
|
||||
sync = not self.updates.handlers
|
||||
self.updates.handlers.append(handler)
|
||||
if sync:
|
||||
self.sync_updates()
|
||||
|
||||
def remove_update_handler(self, handler):
|
||||
self.updates.handlers.remove(handler)
|
||||
|
||||
def list_update_handlers(self):
|
||||
return self.updates.handlers[:]
|
||||
|
||||
# endregion
|
||||
|
||||
# Constant read
|
||||
|
||||
# By using this approach, another thread will be
|
||||
# created and started upon connection to constantly read
|
||||
# from the other end. Otherwise, manual calls to .receive()
|
||||
# must be performed. The MtProtoSender cannot be connected,
|
||||
# or an error will be thrown.
|
||||
#
|
||||
# This way, sending and receiving will be completely independent.
|
||||
def _recv_thread_impl(self):
|
||||
while self._sender and self._sender.is_connected():
|
||||
try:
|
||||
if datetime.now() > self._last_ping + self._ping_delay:
|
||||
self._sender.send(PingRequest(
|
||||
int.from_bytes(os.urandom(8), 'big', signed=True)
|
||||
))
|
||||
self._last_ping = datetime.now()
|
||||
|
||||
self._sender.receive(update_state=self.updates)
|
||||
except AttributeError:
|
||||
# 'NoneType' object has no attribute 'receive'.
|
||||
# The only moment when this can happen is reconnection
|
||||
# was triggered from another thread and the ._sender
|
||||
# was set to None, so close this thread and exit by return.
|
||||
self._recv_thread = None
|
||||
return
|
||||
except TimeoutError:
|
||||
# No problem.
|
||||
pass
|
||||
except ConnectionResetError:
|
||||
if self._recv_thread is not None:
|
||||
# Do NOT attempt reconnecting unless the connection was
|
||||
# finished by the user -> ._recv_thread is None
|
||||
self._logger.debug('Server disconnected us. Reconnecting...')
|
||||
self._recv_thread = None # Not running anymore
|
||||
self.reconnect()
|
||||
return
|
||||
except Exception as e:
|
||||
# Unknown exception, pass it to the main thread
|
||||
self.updates.set_error(e)
|
||||
self._recv_thread = None
|
||||
return
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,2 +1,5 @@
|
|||
from .tlobject import TLObject
|
||||
from .session import Session
|
||||
from .gzip_packed import GzipPacked
|
||||
from .tl_message import TLMessage
|
||||
from .message_container import MessageContainer
|
||||
|
|
telethon/tl/gzip_packed.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import gzip
import struct

from . import TLObject


class GzipPacked(TLObject):
    CONSTRUCTOR_ID = 0x3072cfa1

    def __init__(self, data):
        super().__init__()
        self.data = data

    @staticmethod
    def gzip_if_smaller(request):
        """Calls request.to_bytes(), and based on a certain threshold,
        optionally gzips the resulting data. If the gzipped data is
        smaller than the original byte array, this is returned instead.

        Note that this only applies to content related requests.
        """
        data = request.to_bytes()
        # TODO This threshold could be configurable
        if request.content_related and len(data) > 512:
            gzipped = GzipPacked(data).to_bytes()
            return gzipped if len(gzipped) < len(data) else data
        else:
            return data

    def to_bytes(self):
        # TODO Maybe compress level could be an option
        return struct.pack('<I', GzipPacked.CONSTRUCTOR_ID) + \
            TLObject.serialize_bytes(gzip.compress(self.data))

    @staticmethod
    def read(reader):
        reader.read_int(signed=False)  # code
        return gzip.decompress(reader.tgread_bytes())
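A quick hedged check of the threshold logic above. The dummy request class is an assumption used only to drive gzip_if_smaller; in the library the argument is a real content-related TLObject request.

from telethon.tl import GzipPacked          # exported by telethon/tl/__init__.py

class FakeRequest:                          # stand-in for a content-related request
    content_related = True
    def __init__(self, data):
        self._data = data
    def to_bytes(self):
        return self._data

small = GzipPacked.gzip_if_smaller(FakeRequest(b'x' * 16))
large = GzipPacked.gzip_if_smaller(FakeRequest(b'x' * 4096))
assert small == b'x' * 16                   # at or below the 512-byte threshold: untouched
assert len(large) < 4096                    # highly compressible payload: gzipped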
|
telethon/tl/message_container.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import struct

from . import TLObject


class MessageContainer(TLObject):
    CONSTRUCTOR_ID = 0x73f1f8dc

    def __init__(self, messages):
        super().__init__()
        self.content_related = False
        self.messages = messages

    def to_bytes(self):
        return struct.pack(
            '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
        ) + b''.join(m.to_bytes() for m in self.messages)

    @staticmethod
    def iter_read(reader):
        reader.read_int(signed=False)  # code
        size = reader.read_int()
        for _ in range(size):
            inner_msg_id = reader.read_long()
            inner_sequence = reader.read_int()
            inner_length = reader.read_int()
            yield inner_msg_id, inner_sequence, inner_length
|
|
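The container's wire layout is just the constructor id, a signed 32-bit count, and the concatenated messages; iter_read walks the inverse. A struct-only sketch of the outer framing (the inner byte strings here stand in for real TLMessage blobs, which carry their own msg_id/seq_no/length header):

```python
import struct

CONSTRUCTOR_ID = 0x73f1f8dc

# Pretend these are already-serialized inner messages
inner = [b'\x01\x02\x03', b'\x04\x05']
container = struct.pack('<Ii', CONSTRUCTOR_ID, len(inner)) + b''.join(inner)

# Reading back the header is the inverse of to_bytes() above
code, count = struct.unpack_from('<Ii', container, 0)
assert code == CONSTRUCTOR_ID and count == len(inner)
```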
@@ -118,7 +118,7 @@ class Session:
    # FIXME We need to import the AuthKey here or otherwise
    # we get cyclic dependencies.
    from ..crypto import AuthKey
    if data['auth_key_data'] is not None:
    if data.get('auth_key_data', None) is not None:
        key = b64decode(data['auth_key_data'])
        result.auth_key = AuthKey(data=key)
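The change to data.get('auth_key_data', None) makes loading older session files that lack the key a no-op rather than a KeyError. A tiny illustration with a hypothetical legacy session dict:

```python
# Loading older session data that predates 'auth_key_data' should not crash:
data = {'port': 443}                      # hypothetical legacy session dict
key_data = data.get('auth_key_data', None)
if key_data is not None:
    print('would decode the key and build an AuthKey here')
else:
    print('no stored auth key, a new one will be negotiated')
```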
17
telethon/tl/tl_message.py
Normal file
@@ -0,0 +1,17 @@
import struct

from . import TLObject, GzipPacked


class TLMessage(TLObject):
    """https://core.telegram.org/mtproto/service_messages#simple-container"""
    def __init__(self, session, request):
        super().__init__()
        del self.content_related
        self.msg_id = session.get_new_msg_id()
        self.seq_no = session.generate_sequence(request.content_related)
        self.request = request

    def to_bytes(self):
        body = GzipPacked.gzip_if_smaller(self.request)
        return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
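TLMessage frames the (possibly gzipped) request as a little-endian '<qii' header, i.e. 64-bit msg_id, 32-bit seq_no and 32-bit body length, followed by the body. A self-contained sketch with made-up values:

```python
import struct

# Hypothetical values; in the library they come from the Session object
msg_id = 6543210987654321234
seq_no = 3
body = b'\xee' * 20                       # an already-serialized request

blob = struct.pack('<qii', msg_id, seq_no, len(body)) + body

# The receiving side can peel the same header off again
got_id, got_seq, got_len = struct.unpack_from('<qii', blob, 0)
assert (got_id, got_seq, got_len) == (msg_id, seq_no, len(body))
```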
@@ -9,7 +9,6 @@ class TLObject:
        self.rpc_error = None

        # These should be overrode
        self.constructor_id = 0
        self.content_related = False  # Only requests/functions/queries are

        # These should not be overrode

@@ -20,10 +19,13 @@ class TLObject:
        """
        if indent is None:
            if isinstance(obj, TLObject):
                return '{{{}: {}}}'.format(
                    type(obj).__name__,
                    TLObject.pretty_format(obj.to_dict())
                )
                children = obj.to_dict(recursive=False)
                if children:
                    return '{}: {}'.format(
                        type(obj).__name__, TLObject.pretty_format(children)
                    )
                else:
                    return type(obj).__name__
            if isinstance(obj, dict):
                return '{{{}}}'.format(', '.join(
                    '{}: {}'.format(

@@ -41,12 +43,13 @@ class TLObject:
        else:
            result = []
            if isinstance(obj, TLObject):
                result.append('{')
                result.append(type(obj).__name__)
                result.append(': ')
                result.append(TLObject.pretty_format(
                    obj.to_dict(), indent
                ))
                children = obj.to_dict(recursive=False)
                if children:
                    result.append(': ')
                    result.append(TLObject.pretty_format(
                        obj.to_dict(recursive=False), indent
                    ))

            elif isinstance(obj, dict):
                result.append('{\n')

@@ -80,12 +83,43 @@ class TLObject:

        return ''.join(result)

    @staticmethod
    def serialize_bytes(data):
        """Write bytes by using Telegram guidelines"""
        if isinstance(data, str):
            data = data.encode('utf-8')

        r = []
        if len(data) < 254:
            padding = (len(data) + 1) % 4
            if padding != 0:
                padding = 4 - padding

            r.append(bytes([len(data)]))
            r.append(data)

        else:
            padding = len(data) % 4
            if padding != 0:
                padding = 4 - padding

            r.append(bytes([
                254,
                len(data) % 256,
                (len(data) >> 8) % 256,
                (len(data) >> 16) % 256
            ]))
            r.append(data)

        r.append(bytes(padding))
        return b''.join(r)

    # These should be overrode
    def to_dict(self):
    def to_dict(self, recursive=True):
        return {}

    def on_send(self, writer):
        pass
    def to_bytes(self):
        return b''

    def on_response(self, reader):
        pass
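serialize_bytes implements TL's string/bytes encoding: payloads under 254 bytes get a one-byte length prefix, longer ones a 0xFE marker plus a 3-byte little-endian length, and everything is padded to a multiple of four. An equivalent, self-contained reformulation to sanity-check those rules:

```python
def tl_serialize_bytes(data: bytes) -> bytes:
    # Same rules as TLObject.serialize_bytes above, written more compactly
    if len(data) < 254:
        header = bytes([len(data)])
    else:
        header = bytes([254]) + len(data).to_bytes(3, 'little')
    padding = -(len(header) + len(data)) % 4
    return header + data + bytes(padding)


assert tl_serialize_bytes(b'abc') == b'\x03abc'           # short form
assert len(tl_serialize_bytes(b'abc')) % 4 == 0
assert tl_serialize_bytes(b'x' * 300)[:4] == b'\xfe\x2c\x01\x00'  # long form
assert len(tl_serialize_bytes(b'x' * 300)) % 4 == 0
```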
@@ -1,6 +1,7 @@
import logging
from collections import deque
from datetime import datetime
from threading import RLock, Event
from threading import RLock, Event, Thread

from .tl import types as tl

@@ -9,78 +10,136 @@ class UpdateState:
    """Used to hold the current state of processed updates.
    To retrieve an update, .poll() should be called.
    """
    def __init__(self, polling):
        self._polling = polling
    WORKER_POLL_TIMEOUT = 5.0  # Avoid waiting forever on the workers

    def __init__(self, workers=None):
        """
        :param workers: This integer parameter has three possible cases:
            workers is None: Updates will *not* be stored on self.
            workers = 0: Another thread is responsible for calling self.poll()
            workers > 0: 'workers' background threads will be spawned, any
                         any of them will invoke all the self.handlers.
        """
        self._workers = workers
        self._worker_threads = []

        self.handlers = []
        self._updates_lock = RLock()
        self._updates_available = Event()
        self._updates = deque()

        self._logger = logging.getLogger(__name__)

        # https://core.telegram.org/api/updates
        self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
        self._setup_workers()

    def can_poll(self):
        """Returns True if a call to .poll() won't lock"""
        return self._updates_available.is_set()

    def poll(self):
        """Polls an update or blocks until an update object is available"""
        if not self._polling:
            raise ValueError('Updates are not being polled hence not saved.')
    def poll(self, timeout=None):
        """Polls an update or blocks until an update object is available.
        If 'timeout is not None', it should be a floating point value,
        and the method will 'return None' if waiting times out.
        """
        if not self._updates_available.wait(timeout=timeout):
            return

        self._updates_available.wait()
        with self._updates_lock:
            if not self._updates_available.is_set():
                return

            update = self._updates.popleft()
            if not self._updates:
                self._updates_available.clear()

        if isinstance(update, Exception):
            raise update  # Some error was set through .set_error()
            raise update  # Some error was set through (surely StopIteration)

        return update

    def get_polling(self):
        return self._polling
    def get_workers(self):
        return self._workers

    def set_polling(self, polling):
        self._polling = polling
        if not polling:
            with self._updates_lock:
                self._updates.clear()

    polling = property(fget=get_polling, fset=set_polling)

    def set_error(self, error):
        """Sets an error, so that the next call to .poll() will raise it.
        Can be (and is) used to pass exceptions between threads.
    def set_workers(self, n):
        """Changes the number of workers running.
        If 'n is None', clears all pending updates from memory.
        """
        with self._updates_lock:
            # Insert at the beginning so the very next poll causes an error
            # TODO Should this reset the pts and such?
            self._updates.insert(0, error)
            self._updates_available.set()
        self._stop_workers()
        self._workers = n
        if n is None:
            self._updates.clear()
        else:
            self._setup_workers()

    def check_error(self):
        with self._updates_lock:
            if self._updates and isinstance(self._updates[0], Exception):
                raise self._updates.pop()
    workers = property(fget=get_workers, fset=set_workers)

    def _stop_workers(self):
        """Raises "StopIterationException" on the worker threads to stop them,
        and also clears all of them off the list
        """
        if self._workers:
            with self._updates_lock:
                # Insert at the beginning so the very next poll causes an error
                # on all the worker threads
                # TODO Should this reset the pts and such?
                for _ in range(self._workers):
                    self._updates.appendleft(StopIteration())
                self._updates_available.set()

        for t in self._worker_threads:
            t.join()

        self._worker_threads.clear()

    def _setup_workers(self):
        if self._worker_threads or not self._workers:
            # There already are workers, or workers is None or 0. Do nothing.
            return

        for i in range(self._workers):
            thread = Thread(
                target=UpdateState._worker_loop,
                name='UpdateWorker{}'.format(i),
                daemon=True,
                args=(self, i)
            )
            self._worker_threads.append(thread)
            thread.start()

    def _worker_loop(self, wid):
        while True:
            try:
                update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
                # TODO Maybe people can add different handlers per update type
                if update:
                    for handler in self.handlers:
                        handler(update)
            except StopIteration:
                break
            except Exception as e:
                # We don't want to crash a worker thread due to any reason
                self._logger.debug(
                    '[ERROR] Unhandled exception on worker {}'.format(wid), e
                )

    def process(self, update):
        """Processes an update object. This method is normally called by
        the library itself.
        """
        if not self._polling and not self.handlers:
            return
        if self._workers is None:
            return  # No processing needs to be done if nobody's working

        with self._updates_lock:
            if isinstance(update, tl.updates.State):
                self._state = update
            elif not hasattr(update, 'pts') or update.pts > self._state.pts:
                self._state.pts = getattr(update, 'pts', self._state.pts)
                return  # Nothing else to be done

            if self._polling:
                self._updates.append(update)
                self._updates_available.set()
            pts = getattr(update, 'pts', self._state.pts)
            if hasattr(update, 'pts') and pts <= self._state.pts:
                return  # We already handled this update

            for handler in self.handlers:
                handler(update)
            self._state.pts = pts
            self._updates.append(update)
            self._updates_available.set()
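The worker scheme above boils down to: N daemon threads poll a deque guarded by a lock and an Event, and shutdown is signalled by queueing one StopIteration sentinel per worker. A minimal standalone sketch of that pattern (plain stdlib, illustrative names, not Telethon's API):

```python
from collections import deque
from threading import Event, Lock, Thread

updates = deque()
available = Event()
lock = Lock()


def poll(timeout):
    # Wait until something is queued, then pop it; sentinels are raised
    if not available.wait(timeout=timeout):
        return None
    with lock:
        if not updates:
            available.clear()
            return None
        item = updates.popleft()
        if not updates:
            available.clear()
        if isinstance(item, Exception):
            raise item            # StopIteration sentinel stops a worker
        return item


def worker(handler):
    while True:
        try:
            item = poll(timeout=0.5)
            if item is not None:
                handler(item)
        except StopIteration:
            break


threads = [Thread(target=worker, args=(print,), daemon=True) for _ in range(2)]
for t in threads:
    t.start()

with lock:
    updates.append('update #1')
    for _ in threads:
        updates.append(StopIteration())   # one sentinel per worker
    available.set()

for t in threads:
    t.join()
```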
@ -10,8 +10,17 @@ from .tl.types import (
|
|||
ChatPhoto, InputPeerChannel, InputPeerChat, InputPeerUser, InputPeerEmpty,
|
||||
MessageMediaDocument, MessageMediaPhoto, PeerChannel, InputChannel,
|
||||
UserEmpty, InputUser, InputUserEmpty, InputUserSelf, InputPeerSelf,
|
||||
PeerChat, PeerUser, User, UserFull, UserProfilePhoto, Document
|
||||
)
|
||||
PeerChat, PeerUser, User, UserFull, UserProfilePhoto, Document,
|
||||
MessageMediaContact, MessageMediaEmpty, MessageMediaGame, MessageMediaGeo,
|
||||
MessageMediaUnsupported, MessageMediaVenue, InputMediaContact,
|
||||
InputMediaDocument, InputMediaEmpty, InputMediaGame,
|
||||
InputMediaGeoPoint, InputMediaPhoto, InputMediaVenue, InputDocument,
|
||||
DocumentEmpty, InputDocumentEmpty, Message, GeoPoint, InputGeoPoint,
|
||||
GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty,
|
||||
InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty,
|
||||
FileLocationUnavailable, InputMediaUploadedDocument,
|
||||
InputMediaUploadedPhoto,
|
||||
DocumentAttributeFilename)
|
||||
|
||||
|
||||
def get_display_name(entity):
|
||||
|
@ -65,13 +74,10 @@ def _raise_cast_fail(entity, target):
|
|||
def get_input_peer(entity):
|
||||
"""Gets the input peer for the given "entity" (user, chat or channel).
|
||||
A ValueError is raised if the given entity isn't a supported type."""
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
if not isinstance(entity, TLObject):
|
||||
_raise_cast_fail(entity, 'InputPeer')
|
||||
|
||||
if type(entity).subclass_of_id == 0xc91c90b6: # crc32(b'InputPeer')
|
||||
if type(entity).SUBCLASS_OF_ID == 0xc91c90b6: # crc32(b'InputPeer')
|
||||
return entity
|
||||
|
||||
if isinstance(entity, User):
|
||||
|
@ -109,13 +115,10 @@ def get_input_peer(entity):
|
|||
|
||||
def get_input_channel(entity):
|
||||
"""Similar to get_input_peer, but for InputChannel's alone"""
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
if not isinstance(entity, TLObject):
|
||||
_raise_cast_fail(entity, 'InputChannel')
|
||||
|
||||
if type(entity).subclass_of_id == 0x40f202fd: # crc32(b'InputChannel')
|
||||
if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel')
|
||||
return entity
|
||||
|
||||
if isinstance(entity, Channel) or isinstance(entity, ChannelForbidden):
|
||||
|
@ -129,13 +132,10 @@ def get_input_channel(entity):
|
|||
|
||||
def get_input_user(entity):
|
||||
"""Similar to get_input_peer, but for InputUser's alone"""
|
||||
if entity is None:
|
||||
return None
|
||||
|
||||
if not isinstance(entity, TLObject):
|
||||
_raise_cast_fail(entity, 'InputUser')
|
||||
|
||||
if type(entity).subclass_of_id == 0xe669bf46: # crc32(b'InputUser')
|
||||
if type(entity).SUBCLASS_OF_ID == 0xe669bf46: # crc32(b'InputUser')
|
||||
return entity
|
||||
|
||||
if isinstance(entity, User):
|
||||
|
@ -156,27 +156,169 @@ def get_input_user(entity):
|
|||
_raise_cast_fail(entity, 'InputUser')
|
||||
|
||||
|
||||
def get_input_document(document):
|
||||
"""Similar to get_input_peer, but for documents"""
|
||||
if not isinstance(document, TLObject):
|
||||
_raise_cast_fail(document, 'InputDocument')
|
||||
|
||||
if type(document).SUBCLASS_OF_ID == 0xf33fdb68: # crc32(b'InputDocument')
|
||||
return document
|
||||
|
||||
if isinstance(document, Document):
|
||||
return InputDocument(id=document.id, access_hash=document.access_hash)
|
||||
|
||||
if isinstance(document, DocumentEmpty):
|
||||
return InputDocumentEmpty()
|
||||
|
||||
if isinstance(document, MessageMediaDocument):
|
||||
return get_input_document(document.document)
|
||||
|
||||
if isinstance(document, Message):
|
||||
return get_input_document(document.media)
|
||||
|
||||
_raise_cast_fail(document, 'InputDocument')
|
||||
|
||||
|
||||
def get_input_photo(photo):
|
||||
"""Similar to get_input_peer, but for documents"""
|
||||
if not isinstance(photo, TLObject):
|
||||
_raise_cast_fail(photo, 'InputPhoto')
|
||||
|
||||
if type(photo).SUBCLASS_OF_ID == 0x846363e0: # crc32(b'InputPhoto')
|
||||
return photo
|
||||
|
||||
if isinstance(photo, Photo):
|
||||
return InputPhoto(id=photo.id, access_hash=photo.access_hash)
|
||||
|
||||
if isinstance(photo, PhotoEmpty):
|
||||
return InputPhotoEmpty()
|
||||
|
||||
_raise_cast_fail(photo, 'InputPhoto')
|
||||
|
||||
|
||||
def get_input_geo(geo):
|
||||
"""Similar to get_input_peer, but for geo points"""
|
||||
if not isinstance(geo, TLObject):
|
||||
_raise_cast_fail(geo, 'InputGeoPoint')
|
||||
|
||||
if type(geo).SUBCLASS_OF_ID == 0x430d225: # crc32(b'InputGeoPoint')
|
||||
return geo
|
||||
|
||||
if isinstance(geo, GeoPoint):
|
||||
return InputGeoPoint(lat=geo.lat, long=geo.long)
|
||||
|
||||
if isinstance(geo, GeoPointEmpty):
|
||||
return InputGeoPointEmpty()
|
||||
|
||||
if isinstance(geo, MessageMediaGeo):
|
||||
return get_input_geo(geo.geo)
|
||||
|
||||
if isinstance(geo, Message):
|
||||
return get_input_geo(geo.media)
|
||||
|
||||
_raise_cast_fail(geo, 'InputGeoPoint')
|
||||
|
||||
|
||||
def get_input_media(media, user_caption=None, is_photo=False):
|
||||
"""Similar to get_input_peer, but for media.
|
||||
|
||||
If the media is a file location and is_photo is known to be True,
|
||||
it will be treated as an InputMediaUploadedPhoto.
|
||||
"""
|
||||
if not isinstance(media, TLObject):
|
||||
_raise_cast_fail(media, 'InputMedia')
|
||||
|
||||
if type(media).SUBCLASS_OF_ID == 0xfaf846f4: # crc32(b'InputMedia')
|
||||
return media
|
||||
|
||||
if isinstance(media, MessageMediaPhoto):
|
||||
return InputMediaPhoto(
|
||||
id=get_input_photo(media.photo),
|
||||
caption=media.caption if user_caption is None else user_caption,
|
||||
ttl_seconds=media.ttl_seconds
|
||||
)
|
||||
|
||||
if isinstance(media, MessageMediaDocument):
|
||||
return InputMediaDocument(
|
||||
id=get_input_document(media.document),
|
||||
caption=media.caption if user_caption is None else user_caption,
|
||||
ttl_seconds=media.ttl_seconds
|
||||
)
|
||||
|
||||
if isinstance(media, FileLocation):
|
||||
if is_photo:
|
||||
return InputMediaUploadedPhoto(
|
||||
file=media,
|
||||
caption=user_caption or ''
|
||||
)
|
||||
else:
|
||||
return InputMediaUploadedDocument(
|
||||
file=media,
|
||||
mime_type='application/octet-stream', # unknown, assume bytes
|
||||
attributes=[DocumentAttributeFilename('unnamed')],
|
||||
caption=user_caption or ''
|
||||
)
|
||||
|
||||
if isinstance(media, MessageMediaGame):
|
||||
return InputMediaGame(id=media.game.id)
|
||||
|
||||
if isinstance(media, ChatPhoto) or isinstance(media, UserProfilePhoto):
|
||||
if isinstance(media.photo_big, FileLocationUnavailable):
|
||||
return get_input_media(media.photo_small, is_photo=True)
|
||||
else:
|
||||
return get_input_media(media.photo_big, is_photo=True)
|
||||
|
||||
if isinstance(media, MessageMediaContact):
|
||||
return InputMediaContact(
|
||||
phone_number=media.phone_number,
|
||||
first_name=media.first_name,
|
||||
last_name=media.last_name
|
||||
)
|
||||
|
||||
if isinstance(media, MessageMediaGeo):
|
||||
return InputMediaGeoPoint(geo_point=get_input_geo(media.geo))
|
||||
|
||||
if isinstance(media, MessageMediaVenue):
|
||||
return InputMediaVenue(
|
||||
geo_point=get_input_geo(media.geo),
|
||||
title=media.title,
|
||||
address=media.address,
|
||||
provider=media.provider,
|
||||
venue_id=media.venue_id
|
||||
)
|
||||
|
||||
if any(isinstance(media, t) for t in (
|
||||
MessageMediaEmpty, MessageMediaUnsupported,
|
||||
FileLocationUnavailable, ChatPhotoEmpty,
|
||||
UserProfilePhotoEmpty)):
|
||||
return InputMediaEmpty()
|
||||
|
||||
if isinstance(media, Message):
|
||||
return get_input_media(media.media)
|
||||
|
||||
_raise_cast_fail(media, 'InputMedia')
|
||||
|
||||
|
||||
def find_user_or_chat(peer, users, chats):
|
||||
"""Finds the corresponding user or chat given a peer.
|
||||
Returns None if it was not found"""
|
||||
try:
|
||||
if isinstance(peer, PeerUser):
|
||||
return next(u for u in users if u.id == peer.user_id)
|
||||
|
||||
elif isinstance(peer, PeerChat):
|
||||
return next(c for c in chats if c.id == peer.chat_id)
|
||||
|
||||
if isinstance(peer, PeerUser):
|
||||
peer, where = peer.user_id, users
|
||||
else:
|
||||
where = chats
|
||||
if isinstance(peer, PeerChat):
|
||||
peer = peer.chat_id
|
||||
elif isinstance(peer, PeerChannel):
|
||||
return next(c for c in chats if c.id == peer.channel_id)
|
||||
|
||||
except StopIteration: return
|
||||
peer = peer.channel_id
|
||||
|
||||
if isinstance(peer, int):
|
||||
try: return next(u for u in users if u.id == peer)
|
||||
except StopIteration: pass
|
||||
|
||||
try: return next(c for c in chats if c.id == peer)
|
||||
except StopIteration: pass
|
||||
if isinstance(where, dict):
|
||||
return where.get(peer)
|
||||
else:
|
||||
try:
|
||||
return next(x for x in where if x.id == peer)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
def get_appropriated_part_size(file_size):
|
||||
|
|
|
@ -96,6 +96,17 @@ class TLObject:
|
|||
result=match.group(3),
|
||||
is_function=is_function)
|
||||
|
||||
def class_name(self):
|
||||
"""Gets the class name following the Python style guidelines"""
|
||||
|
||||
# Courtesy of http://stackoverflow.com/a/31531797/4759433
|
||||
result = re.sub(r'_([a-z])', lambda m: m.group(1).upper(), self.name)
|
||||
result = result[:1].upper() + result[1:].replace('_', '')
|
||||
# If it's a function, let it end with "Request" to identify them
|
||||
if self.is_function:
|
||||
result += 'Request'
|
||||
return result
|
||||
|
||||
def sorted_args(self):
|
||||
"""Returns the arguments properly sorted and ready to plug-in
|
||||
into a Python's method header (i.e., flags and those which
|
||||
|
@ -197,8 +208,8 @@ class TLArg:
|
|||
else:
|
||||
self.flag_indicator = False
|
||||
self.is_generic = arg_type.startswith('!')
|
||||
self.type = arg_type.lstrip(
|
||||
'!') # Strip the exclamation mark always to have only the name
|
||||
# Strip the exclamation mark always to have only the name
|
||||
self.type = arg_type.lstrip('!')
|
||||
|
||||
# The type may be a flag (flags.IDX?REAL_TYPE)
|
||||
# Note that 'flags' is NOT the flags name; this is determined by a previous argument
|
||||
|
@ -233,6 +244,24 @@ class TLArg:
|
|||
|
||||
self.generic_definition = generic_definition
|
||||
|
||||
def type_hint(self):
|
||||
result = {
|
||||
'int': 'int',
|
||||
'long': 'int',
|
||||
'int128': 'int',
|
||||
'int256': 'int',
|
||||
'string': 'str',
|
||||
'date': 'datetime.datetime | None', # None date = 0 timestamp
|
||||
'bytes': 'bytes',
|
||||
'true': 'bool',
|
||||
}.get(self.type, 'TLObject')
|
||||
if self.is_vector:
|
||||
result = 'list[{}]'.format(result)
|
||||
if self.is_flag and self.type != 'date':
|
||||
result += ' | None'
|
||||
|
||||
return result
|
||||
|
||||
def __str__(self):
|
||||
# Find the real type representation by updating it as required
|
||||
real_type = self.type
|
||||
|
|
|
@@ -32,16 +32,16 @@
/// Authorization key creation
///////////////////////////////

resPQ#05162463 nonce:int128 server_nonce:int128 pq:string server_public_key_fingerprints:Vector<long> = ResPQ;
resPQ#05162463 nonce:int128 server_nonce:int128 pq:bytes server_public_key_fingerprints:Vector<long> = ResPQ;

p_q_inner_data#83c95aec pq:string p:string q:string nonce:int128 server_nonce:int128 new_nonce:int256 = P_Q_inner_data;
p_q_inner_data#83c95aec pq:bytes p:bytes q:bytes nonce:int128 server_nonce:int128 new_nonce:int256 = P_Q_inner_data;

server_DH_params_fail#79cb045d nonce:int128 server_nonce:int128 new_nonce_hash:int128 = Server_DH_Params;
server_DH_params_ok#d0e8075c nonce:int128 server_nonce:int128 encrypted_answer:string = Server_DH_Params;
server_DH_params_ok#d0e8075c nonce:int128 server_nonce:int128 encrypted_answer:bytes = Server_DH_Params;

server_DH_inner_data#b5890dba nonce:int128 server_nonce:int128 g:int dh_prime:string g_a:string server_time:int = Server_DH_inner_data;
server_DH_inner_data#b5890dba nonce:int128 server_nonce:int128 g:int dh_prime:bytes g_a:bytes server_time:int = Server_DH_inner_data;

client_DH_inner_data#6643b654 nonce:int128 server_nonce:int128 retry_id:long g_b:string = Client_DH_Inner_Data;
client_DH_inner_data#6643b654 nonce:int128 server_nonce:int128 retry_id:long g_b:bytes = Client_DH_Inner_Data;

dh_gen_ok#3bcbf734 nonce:int128 server_nonce:int128 new_nonce_hash1:int128 = Set_client_DH_params_answer;
dh_gen_retry#46dc1fb9 nonce:int128 server_nonce:int128 new_nonce_hash2:int128 = Set_client_DH_params_answer;

@@ -55,9 +55,9 @@ destroy_auth_key_fail#ea109b13 = DestroyAuthKeyRes;
req_pq#60469778 nonce:int128 = ResPQ;

req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:string q:string public_key_fingerprint:long encrypted_data:string = Server_DH_Params;
req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:bytes q:bytes public_key_fingerprint:long encrypted_data:bytes = Server_DH_Params;

set_client_DH_params#f5045f1f nonce:int128 server_nonce:int128 encrypted_data:string = Set_client_DH_params_answer;
set_client_DH_params#f5045f1f nonce:int128 server_nonce:int128 encrypted_data:bytes = Set_client_DH_params_answer;

destroy_auth_key#d1435160 = DestroyAuthKeyRes;
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
import struct
|
||||
from zlib import crc32
|
||||
from collections import defaultdict
|
||||
|
||||
|
@ -107,8 +108,7 @@ class TLGenerator:
|
|||
if tlobject.namespace:
|
||||
builder.write('.' + tlobject.namespace)
|
||||
|
||||
builder.writeln('.{},'.format(
|
||||
TLGenerator.get_class_name(tlobject)))
|
||||
builder.writeln('.{},'.format(tlobject.class_name()))
|
||||
|
||||
builder.current_indent -= 1
|
||||
builder.writeln('}')
|
||||
|
@ -137,13 +137,29 @@ class TLGenerator:
|
|||
x for x in namespace_tlobjects.keys() if x
|
||||
)))
|
||||
|
||||
# Import 'get_input_*' utils
|
||||
# TODO Support them on types too
|
||||
if 'functions' in out_dir:
|
||||
builder.writeln(
|
||||
'from {}.utils import get_input_peer, '
|
||||
'get_input_channel, get_input_user, '
|
||||
'get_input_media'.format('.' * depth)
|
||||
)
|
||||
|
||||
# Import 'os' for those needing access to 'os.urandom()'
|
||||
# Currently only 'random_id' needs 'os' to be imported,
|
||||
# for all those TLObjects with arg.can_be_inferred.
|
||||
builder.writeln('import os')
|
||||
|
||||
# Import struct for the .to_bytes(self) serialization
|
||||
builder.writeln('import struct')
|
||||
|
||||
# Generate the class for every TLObject
|
||||
for t in sorted(tlobjects, key=lambda x: x.name):
|
||||
TLGenerator._write_source_code(
|
||||
t, builder, depth, type_constructors
|
||||
)
|
||||
while builder.current_indent != 0:
|
||||
builder.end_block()
|
||||
builder.current_indent = 0
|
||||
|
||||
@staticmethod
|
||||
def _write_source_code(tlobject, builder, depth, type_constructors):
|
||||
|
@ -154,35 +170,15 @@ class TLGenerator:
|
|||
the Type: [Constructors] must be given for proper
|
||||
importing and documentation strings.
|
||||
"""
|
||||
if tlobject.is_function:
|
||||
util_imports = set()
|
||||
for a in tlobject.args:
|
||||
# We can automatically convert some "full" types to
|
||||
# "input only" (like User -> InputPeerUser, etc.)
|
||||
if a.type == 'InputPeer':
|
||||
util_imports.add('get_input_peer')
|
||||
elif a.type == 'InputChannel':
|
||||
util_imports.add('get_input_channel')
|
||||
elif a.type == 'InputUser':
|
||||
util_imports.add('get_input_user')
|
||||
|
||||
if util_imports:
|
||||
builder.writeln('from {}.utils import {}'.format(
|
||||
'.' * depth, ', '.join(util_imports)))
|
||||
|
||||
if any(a for a in tlobject.args if a.can_be_inferred):
|
||||
# Currently only 'random_id' needs 'os' to be imported
|
||||
builder.writeln('import os')
|
||||
|
||||
builder.writeln()
|
||||
builder.writeln()
|
||||
builder.writeln('class {}(TLObject):'.format(
|
||||
TLGenerator.get_class_name(tlobject)))
|
||||
builder.writeln('class {}(TLObject):'.format(tlobject.class_name()))
|
||||
|
||||
# Class-level variable to store its Telegram's constructor ID
|
||||
builder.writeln('constructor_id = {}'.format(hex(tlobject.id)))
|
||||
builder.writeln('subclass_of_id = {}'.format(
|
||||
hex(crc32(tlobject.result.encode('ascii')))))
|
||||
builder.writeln('CONSTRUCTOR_ID = {}'.format(hex(tlobject.id)))
|
||||
builder.writeln('SUBCLASS_OF_ID = {}'.format(
|
||||
hex(crc32(tlobject.result.encode('ascii'))))
|
||||
)
|
||||
builder.writeln()
|
||||
|
||||
# Flag arguments must go last
|
||||
|
@ -221,17 +217,10 @@ class TLGenerator:
|
|||
builder.writeln('"""')
|
||||
for arg in args:
|
||||
if not arg.flag_indicator:
|
||||
builder.write(
|
||||
':param {}: Telegram type: "{}".'
|
||||
.format(arg.name, arg.type)
|
||||
)
|
||||
if arg.is_vector:
|
||||
builder.write(' Must be a list.'.format(arg.name))
|
||||
|
||||
if arg.is_generic:
|
||||
builder.write(' Must be another TLObject request.')
|
||||
|
||||
builder.writeln()
|
||||
builder.writeln(':param {} {}:'.format(
|
||||
arg.type_hint(), arg.name
|
||||
))
|
||||
builder.current_indent -= 1 # It will auto-indent (':')
|
||||
|
||||
# We also want to know what type this request returns
|
||||
# or to which type this constructor belongs to
|
||||
|
@ -246,12 +235,11 @@ class TLGenerator:
|
|||
builder.writeln('This type has no constructors.')
|
||||
elif len(constructors) == 1:
|
||||
builder.writeln('Instance of {}.'.format(
|
||||
TLGenerator.get_class_name(constructors[0])
|
||||
constructors[0].class_name()
|
||||
))
|
||||
else:
|
||||
builder.writeln('Instance of either {}.'.format(
|
||||
', '.join(TLGenerator.get_class_name(c)
|
||||
for c in constructors)
|
||||
', '.join(c.class_name() for c in constructors)
|
||||
))
|
||||
|
||||
builder.writeln('"""')
|
||||
|
@ -274,56 +262,77 @@ class TLGenerator:
|
|||
builder.end_block()
|
||||
|
||||
# Write the to_dict(self) method
|
||||
builder.writeln('def to_dict(self, recursive=True):')
|
||||
if args:
|
||||
builder.writeln('def to_dict(self):')
|
||||
builder.writeln('return {')
|
||||
builder.current_indent += 1
|
||||
|
||||
base_types = ('string', 'bytes', 'int', 'long', 'int128',
|
||||
'int256', 'double', 'Bool', 'true', 'date')
|
||||
|
||||
for arg in args:
|
||||
builder.write("'{}': ".format(arg.name))
|
||||
if arg.type in base_types:
|
||||
if arg.is_vector:
|
||||
builder.write(
|
||||
'[] if self.{0} is None else self.{0}[:]'
|
||||
.format(arg.name)
|
||||
)
|
||||
else:
|
||||
builder.write('self.{}'.format(arg.name))
|
||||
else:
|
||||
if arg.is_vector:
|
||||
builder.write(
|
||||
'[] if self.{0} is None else [None '
|
||||
'if x is None else x.to_dict() for x in self.{0}]'
|
||||
.format(arg.name)
|
||||
)
|
||||
else:
|
||||
builder.write(
|
||||
'None if self.{0} is None else self.{0}.to_dict()'
|
||||
.format(arg.name)
|
||||
)
|
||||
builder.writeln(',')
|
||||
|
||||
builder.current_indent -= 1
|
||||
builder.writeln("}")
|
||||
else:
|
||||
builder.writeln('@staticmethod')
|
||||
builder.writeln('def to_dict():')
|
||||
builder.writeln('return {}')
|
||||
builder.write('return {')
|
||||
builder.current_indent += 1
|
||||
|
||||
base_types = ('string', 'bytes', 'int', 'long', 'int128',
|
||||
'int256', 'double', 'Bool', 'true', 'date')
|
||||
|
||||
for arg in args:
|
||||
builder.write("'{}': ".format(arg.name))
|
||||
if arg.type in base_types:
|
||||
if arg.is_vector:
|
||||
builder.write('[] if self.{0} is None else self.{0}[:]'
|
||||
.format(arg.name))
|
||||
else:
|
||||
builder.write('self.{}'.format(arg.name))
|
||||
else:
|
||||
if arg.is_vector:
|
||||
builder.write(
|
||||
'([] if self.{0} is None else [None'
|
||||
' if x is None else x.to_dict() for x in self.{0}]'
|
||||
') if recursive else self.{0}'.format(arg.name)
|
||||
)
|
||||
else:
|
||||
builder.write(
|
||||
'(None if self.{0} is None else self.{0}.to_dict())'
|
||||
' if recursive else self.{0}'.format(arg.name)
|
||||
)
|
||||
builder.writeln(',')
|
||||
|
||||
builder.current_indent -= 1
|
||||
builder.writeln("}")
|
||||
|
||||
builder.end_block()
|
||||
|
||||
# Write the on_send(self, writer) function
|
||||
builder.writeln('def on_send(self, writer):')
|
||||
builder.writeln(
|
||||
'writer.write_int({}.constructor_id, signed=False)'
|
||||
.format(TLGenerator.get_class_name(tlobject)))
|
||||
# Write the .to_bytes() function
|
||||
builder.writeln('def to_bytes(self):')
|
||||
|
||||
# Some objects require more than one flag parameter to be set
|
||||
# at the same time. In this case, add an assertion.
|
||||
repeated_args = defaultdict(list)
|
||||
for arg in tlobject.args:
|
||||
if arg.is_flag:
|
||||
repeated_args[arg.flag_index].append(arg)
|
||||
|
||||
for ra in repeated_args.values():
|
||||
if len(ra) > 1:
|
||||
cnd1 = ('self.{} is None'.format(a.name) for a in ra)
|
||||
cnd2 = ('self.{} is not None'.format(a.name) for a in ra)
|
||||
builder.writeln(
|
||||
"assert ({}) or ({}), '{} parameters must all "
|
||||
"be None or neither be None'".format(
|
||||
' and '.join(cnd1), ' and '.join(cnd2),
|
||||
', '.join(a.name for a in ra)
|
||||
)
|
||||
)
|
||||
|
||||
builder.writeln("return b''.join((")
|
||||
builder.current_indent += 1
|
||||
|
||||
# First constructor code, we already know its bytes
|
||||
builder.writeln('{},'.format(repr(struct.pack('<I', tlobject.id))))
|
||||
|
||||
for arg in tlobject.args:
|
||||
TLGenerator.write_onsend_code(builder, arg,
|
||||
tlobject.args)
|
||||
if TLGenerator.write_to_bytes(builder, arg, tlobject.args):
|
||||
builder.writeln(',')
|
||||
|
||||
builder.current_indent -= 1
|
||||
builder.writeln('))')
|
||||
builder.end_block()
|
||||
|
||||
# Write the empty() function, which returns an "empty"
|
||||
|
@ -331,8 +340,8 @@ class TLGenerator:
|
|||
builder.writeln('@staticmethod')
|
||||
builder.writeln('def empty():')
|
||||
builder.writeln('return {}({})'.format(
|
||||
TLGenerator.get_class_name(tlobject), ', '.join(
|
||||
'None' for _ in range(len(args)))))
|
||||
tlobject.class_name(), ', '.join('None' for _ in range(len(args)))
|
||||
))
|
||||
builder.end_block()
|
||||
|
||||
# Write the on_response(self, reader) function
|
||||
|
@ -345,18 +354,15 @@ class TLGenerator:
|
|||
if tlobject.args:
|
||||
for arg in tlobject.args:
|
||||
TLGenerator.write_onresponse_code(
|
||||
builder, arg, tlobject.args)
|
||||
builder, arg, tlobject.args
|
||||
)
|
||||
else:
|
||||
# If there were no arguments, we still need an
|
||||
# on_response method, and hence "pass" if empty
|
||||
builder.writeln('pass')
|
||||
builder.end_block()
|
||||
|
||||
# Write the __repr__(self) and __str__(self) functions
|
||||
builder.writeln('def __repr__(self):')
|
||||
builder.writeln("return '{}'".format(repr(tlobject)))
|
||||
builder.end_block()
|
||||
|
||||
# Write the __str__(self) and stringify(self) functions
|
||||
builder.writeln('def __str__(self):')
|
||||
builder.writeln('return TLObject.pretty_format(self)')
|
||||
builder.end_block()
|
||||
|
@ -398,6 +404,8 @@ class TLGenerator:
|
|||
TLGenerator.write_get_input(builder, arg, 'get_input_channel')
|
||||
elif arg.type == 'InputUser' and tlobject.is_function:
|
||||
TLGenerator.write_get_input(builder, arg, 'get_input_user')
|
||||
elif arg.type == 'InputMedia' and tlobject.is_function:
|
||||
TLGenerator.write_get_input(builder, arg, 'get_input_media')
|
||||
|
||||
else:
|
||||
builder.writeln('self.{0} = {0}'.format(arg.name))
|
||||
|
@ -408,29 +416,14 @@ class TLGenerator:
|
|||
a parameter upon creating the request. Returns False otherwise
|
||||
"""
|
||||
if arg.is_vector:
|
||||
builder.writeln(
|
||||
'self.{0} = [{1}(_x) for _x in {0}]'
|
||||
.format(arg.name, get_input_code)
|
||||
)
|
||||
pass
|
||||
builder.write('self.{0} = [{1}(_x) for _x in {0}]'
|
||||
.format(arg.name, get_input_code))
|
||||
else:
|
||||
builder.writeln(
|
||||
'self.{0} = {1}({0})'.format(arg.name, get_input_code)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_class_name(tlobject):
|
||||
"""Gets the class name following the Python style guidelines"""
|
||||
|
||||
# Courtesy of http://stackoverflow.com/a/31531797/4759433
|
||||
result = re.sub(r'_([a-z])', lambda m: m.group(1).upper(),
|
||||
tlobject.name)
|
||||
result = result[:1].upper() + result[1:].replace(
|
||||
'_', '') # Replace again to fully ensure!
|
||||
# If it's a function, let it end with "Request" to identify them
|
||||
if tlobject.is_function:
|
||||
result += 'Request'
|
||||
return result
|
||||
builder.write('self.{0} = {1}({0})'
|
||||
.format(arg.name, get_input_code))
|
||||
builder.writeln(
|
||||
' if {} else None'.format(arg.name) if arg.is_flag else ''
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_file_name(tlobject, add_extension=False):
|
||||
|
@ -445,18 +438,17 @@ class TLGenerator:
|
|||
return result
|
||||
|
||||
@staticmethod
|
||||
def write_onsend_code(builder, arg, args, name=None):
|
||||
def write_to_bytes(builder, arg, args, name=None):
|
||||
"""
|
||||
Writes the write code for the given argument
|
||||
Writes the .to_bytes() code for the given argument
|
||||
:param builder: The source code builder
|
||||
:param arg: The argument to write
|
||||
:param args: All the other arguments in TLObject same on_send.
|
||||
:param args: All the other arguments in TLObject same to_bytes.
|
||||
This is required to determine the flags value
|
||||
:param name: The name of the argument. Defaults to "self.argname"
|
||||
This argument is an option because it's required when
|
||||
writing Vectors<>
|
||||
"""
|
||||
|
||||
if arg.generic_definition:
|
||||
return # Do nothing, this only specifies a later type
|
||||
|
||||
|
@ -470,73 +462,91 @@ class TLGenerator:
|
|||
if arg.is_flag:
|
||||
if arg.type == 'true':
|
||||
return # Exit, since True type is never written
|
||||
elif arg.is_vector:
|
||||
# Vector flags are special since they consist of 3 values,
|
||||
# so we need an extra join here. Note that empty vector flags
|
||||
# should NOT be sent either!
|
||||
builder.write("b'' if not {} else b''.join((".format(name))
|
||||
else:
|
||||
builder.writeln('if {}:'.format(name))
|
||||
builder.write("b'' if not {} else (".format(name))
|
||||
|
||||
if arg.is_vector:
|
||||
if arg.use_vector_id:
|
||||
builder.writeln('writer.write_int(0x1cb5c415, signed=False)')
|
||||
# vector code, unsigned 0x1cb5c415 as little endian
|
||||
builder.write(r"b'\x15\xc4\xb5\x1c',")
|
||||
|
||||
builder.write("struct.pack('<i', len({})),".format(name))
|
||||
|
||||
# Cannot unpack the values for the outer tuple through *[(
|
||||
# since that's a Python >3.5 feature, so add another join.
|
||||
builder.write("b''.join(")
|
||||
|
||||
builder.writeln('writer.write_int(len({}))'.format(name))
|
||||
builder.writeln('for _x in {}:'.format(name))
|
||||
# Temporary disable .is_vector, not to enter this if again
|
||||
arg.is_vector = False
|
||||
TLGenerator.write_onsend_code(builder, arg, args, name='_x')
|
||||
# Also disable .is_flag since it's not needed per element
|
||||
old_flag = arg.is_flag
|
||||
arg.is_vector = arg.is_flag = False
|
||||
TLGenerator.write_to_bytes(builder, arg, args, name='x')
|
||||
arg.is_vector = True
|
||||
arg.is_flag = old_flag
|
||||
|
||||
builder.write(' for x in {})'.format(name))
|
||||
|
||||
elif arg.flag_indicator:
|
||||
# Calculate the flags with those items which are not None
|
||||
builder.writeln('flags = 0')
|
||||
for flag in args:
|
||||
if flag.is_flag:
|
||||
builder.writeln('flags |= (1 << {}) if {} else 0'.format(
|
||||
flag.flag_index, 'self.{}'.format(flag.name)))
|
||||
|
||||
builder.writeln('writer.write_int(flags)')
|
||||
builder.writeln()
|
||||
builder.write("struct.pack('<I', {})".format(
|
||||
' | '.join('({} if {} else 0)'.format(
|
||||
1 << flag.flag_index, 'self.{}'.format(flag.name)
|
||||
) for flag in args if flag.is_flag)
|
||||
))
|
||||
|
||||
elif 'int' == arg.type:
|
||||
builder.writeln('writer.write_int({})'.format(name))
|
||||
# struct.pack is around 4 times faster than int.to_bytes
|
||||
builder.write("struct.pack('<i', {})".format(name))
|
||||
|
||||
elif 'long' == arg.type:
|
||||
builder.writeln('writer.write_long({})'.format(name))
|
||||
builder.write("struct.pack('<q', {})".format(name))
|
||||
|
||||
elif 'int128' == arg.type:
|
||||
builder.writeln('writer.write_large_int({}, bits=128)'.format(
|
||||
name))
|
||||
builder.write("{}.to_bytes(16, 'little', signed=True)".format(name))
|
||||
|
||||
elif 'int256' == arg.type:
|
||||
builder.writeln('writer.write_large_int({}, bits=256)'.format(
|
||||
name))
|
||||
builder.write("{}.to_bytes(32, 'little', signed=True)".format(name))
|
||||
|
||||
elif 'double' == arg.type:
|
||||
builder.writeln('writer.write_double({})'.format(name))
|
||||
builder.write("struct.pack('<d', {})".format(name))
|
||||
|
||||
elif 'string' == arg.type:
|
||||
builder.writeln('writer.tgwrite_string({})'.format(name))
|
||||
builder.write('TLObject.serialize_bytes({})'.format(name))
|
||||
|
||||
elif 'Bool' == arg.type:
|
||||
builder.writeln('writer.tgwrite_bool({})'.format(name))
|
||||
# 0x997275b5 if boolean else 0xbc799737
|
||||
builder.write(
|
||||
r"b'\xb5ur\x99' if {} else b'7\x97y\xbc'".format(name)
|
||||
)
|
||||
|
||||
elif 'true' == arg.type:
|
||||
pass # These are actually NOT written! Only used for flags
|
||||
|
||||
elif 'bytes' == arg.type:
|
||||
builder.writeln('writer.tgwrite_bytes({})'.format(name))
|
||||
builder.write('TLObject.serialize_bytes({})'.format(name))
|
||||
|
||||
elif 'date' == arg.type: # Custom format
|
||||
builder.writeln('writer.tgwrite_date({})'.format(name))
|
||||
# 0 if datetime is None else int(datetime.timestamp())
|
||||
builder.write(
|
||||
r"b'\0\0\0\0' if {0} is None else "
|
||||
r"struct.pack('<I', int({0}.timestamp()))".format(name)
|
||||
)
|
||||
|
||||
else:
|
||||
# Else it may be a custom type
|
||||
builder.writeln('{}.on_send(writer)'.format(name))
|
||||
|
||||
# End vector and flag blocks if required (if we opened them before)
|
||||
if arg.is_vector:
|
||||
builder.end_block()
|
||||
builder.write('{}.to_bytes()'.format(name))
|
||||
|
||||
if arg.is_flag:
|
||||
builder.end_block()
|
||||
builder.write(')')
|
||||
if arg.is_vector:
|
||||
builder.write(')') # We were using a tuple
|
||||
|
||||
return True # Something was written
|
||||
|
||||
@staticmethod
|
||||
def write_onresponse_code(builder, arg, args, name=None):
|
||||
|
@ -562,8 +572,8 @@ class TLGenerator:
|
|||
was_flag = False
|
||||
if arg.is_flag:
|
||||
was_flag = True
|
||||
builder.writeln('if (flags & (1 << {})) != 0:'.format(
|
||||
arg.flag_index
|
||||
builder.writeln('if flags & {}:'.format(
|
||||
1 << arg.flag_index
|
||||
))
|
||||
# Temporary disable .is_flag not to enter this if
|
||||
# again when calling the method recursively
|
||||
|
|
|
@ -1,90 +1,61 @@
|
|||
import os
|
||||
import unittest
|
||||
from telethon.extensions import BinaryReader, BinaryWriter
|
||||
from telethon.tl import TLObject
|
||||
from telethon.extensions import BinaryReader
|
||||
|
||||
|
||||
class UtilsTests(unittest.TestCase):
|
||||
@staticmethod
|
||||
def test_binary_writer_reader():
|
||||
# Test that we can write and read properly
|
||||
with BinaryWriter() as writer:
|
||||
writer.write_byte(1)
|
||||
writer.write_int(5)
|
||||
writer.write_long(13)
|
||||
writer.write_float(17.0)
|
||||
writer.write_double(25.0)
|
||||
writer.write(bytes([26, 27, 28, 29, 30, 31, 32]))
|
||||
writer.write_large_int(2**127, 128, signed=False)
|
||||
|
||||
data = writer.get_bytes()
|
||||
expected = b'\x01\x05\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\x88A\x00\x00\x00\x00\x00\x00' \
|
||||
b'9@\x1a\x1b\x1c\x1d\x1e\x1f \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80'
|
||||
|
||||
assert data == expected, 'Retrieved data does not match the expected value'
|
||||
# Test that we can read properly
|
||||
data = b'\x01\x05\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
|
||||
b'\x88A\x00\x00\x00\x00\x00\x009@\x1a\x1b\x1c\x1d\x1e\x1f ' \
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' \
|
||||
b'\x00\x80'
|
||||
|
||||
with BinaryReader(data) as reader:
|
||||
value = reader.read_byte()
|
||||
assert value == 1, 'Example byte should be 1 but is {}'.format(
|
||||
value)
|
||||
assert value == 1, 'Example byte should be 1 but is {}'.format(value)
|
||||
|
||||
value = reader.read_int()
|
||||
assert value == 5, 'Example integer should be 5 but is {}'.format(
|
||||
value)
|
||||
assert value == 5, 'Example integer should be 5 but is {}'.format(value)
|
||||
|
||||
value = reader.read_long()
|
||||
assert value == 13, 'Example long integer should be 13 but is {}'.format(
|
||||
value)
|
||||
assert value == 13, 'Example long integer should be 13 but is {}'.format(value)
|
||||
|
||||
value = reader.read_float()
|
||||
assert value == 17.0, 'Example float should be 17.0 but is {}'.format(
|
||||
value)
|
||||
assert value == 17.0, 'Example float should be 17.0 but is {}'.format(value)
|
||||
|
||||
value = reader.read_double()
|
||||
assert value == 25.0, 'Example double should be 25.0 but is {}'.format(
|
||||
value)
|
||||
assert value == 25.0, 'Example double should be 25.0 but is {}'.format(value)
|
||||
|
||||
value = reader.read(7)
|
||||
assert value == bytes([26, 27, 28, 29, 30, 31, 32]), 'Example bytes should be {} but is {}' \
|
||||
.format(bytes([26, 27, 28, 29, 30, 31, 32]), value)
|
||||
|
||||
value = reader.read_large_int(128, signed=False)
|
||||
assert value == 2**127, 'Example large integer should be {} but is {}'.format(
|
||||
2**127, value)
|
||||
|
||||
# Test Telegram that types are written right
|
||||
with BinaryWriter() as writer:
|
||||
writer.write_int(0x60469778)
|
||||
buffer = writer.get_bytes()
|
||||
valid = b'\x78\x97\x46\x60' # Tested written bytes using C#'s MemoryStream
|
||||
|
||||
assert buffer == valid, 'Written type should be {} but is {}'.format(
|
||||
list(valid), list(buffer))
|
||||
assert value == 2**127, 'Example large integer should be {} but is {}'.format(2**127, value)
|
||||
|
||||
@staticmethod
|
||||
def test_binary_tgwriter_tgreader():
|
||||
small_data = os.urandom(33)
|
||||
small_data_padded = os.urandom(
|
||||
19) # +1 byte for length = 20 (evenly divisible by 4)
|
||||
small_data_padded = os.urandom(19) # +1 byte for length = 20 (%4 = 0)
|
||||
|
||||
large_data = os.urandom(999)
|
||||
large_data_padded = os.urandom(1024)
|
||||
|
||||
data = (small_data, small_data_padded, large_data, large_data_padded)
|
||||
string = 'Testing Telegram strings, this should work properly!'
|
||||
serialized = b''.join(TLObject.serialize_bytes(d) for d in data) + \
|
||||
TLObject.serialize_bytes(string)
|
||||
|
||||
with BinaryWriter() as writer:
|
||||
# First write the data
|
||||
with BinaryReader(serialized) as reader:
|
||||
# And then try reading it without errors (it should be unharmed!)
|
||||
for datum in data:
|
||||
writer.tgwrite_bytes(datum)
|
||||
writer.tgwrite_string(string)
|
||||
value = reader.tgread_bytes()
|
||||
assert value == datum, 'Example bytes should be {} but is {}'.format(
|
||||
datum, value)
|
||||
|
||||
with BinaryReader(writer.get_bytes()) as reader:
|
||||
# And then try reading it without errors (it should be unharmed!)
|
||||
for datum in data:
|
||||
value = reader.tgread_bytes()
|
||||
assert value == datum, 'Example bytes should be {} but is {}'.format(
|
||||
datum, value)
|
||||
|
||||
value = reader.tgread_string()
|
||||
assert value == string, 'Example string should be {} but is {}'.format(
|
||||
string, value)
|
||||
value = reader.tgread_string()
|
||||
assert value == string, 'Example string should be {} but is {}'.format(
|
||||
string, value)
|
||||
|
|