diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index c9b27a9f..eb08ea77 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,26 +1,9 @@ - -### What went wrong -Describe what happened or what the error you have is. - -``` -// paste the crash log here if any -``` - -### What I've done -Either a code example of what you were trying to do, or steps to reproduce, or methods you have tried invoking. - -```python -# Add your Python code here -``` - -### More information -If you think other information can be relevant (e.g. operative system or other variables), add it here. diff --git a/docs/generate.py b/docs/generate.py index 80408a47..4feb1518 100755 --- a/docs/generate.py +++ b/docs/generate.py @@ -199,14 +199,27 @@ def generate_index(folder, original_paths): def get_description(arg): """Generates a proper description for the given argument""" desc = [] + otherwise = False if arg.can_be_inferred: - desc.append('If left to None, it will be inferred automatically.') - if arg.is_vector: - desc.append('A list must be supplied for this argument.') - if arg.is_generic: - desc.append('A different Request must be supplied for this argument.') - if arg.is_flag: + desc.append('If left unspecified, it will be inferred automatically.') + otherwise = True + elif arg.is_flag: desc.append('This argument can be omitted.') + otherwise = True + + if arg.is_vector: + if arg.is_generic: + desc.append('A list of other Requests must be supplied.') + else: + desc.append('A list must be supplied.') + elif arg.is_generic: + desc.append('A different Request must be supplied for this argument.') + else: + otherwise = False # Always reset to false if no other text is added + + if otherwise: + desc.insert(1, 'Otherwise,') + desc[-1] = desc[-1][:1].lower() + desc[-1][1:] return ' '.join(desc) @@ -218,6 +231,7 @@ def generate_documentation(scheme_file): original_paths = { 'css': 'css/docs.css', 'arrow': 'img/arrow.svg', + '404': '404.html', 'index_all': 'index.html', 'index_types': 'types/index.html', 'index_methods': 'methods/index.html', @@ -360,7 +374,8 @@ def generate_documentation(scheme_file): for tltype, constructors in tltypes.items(): filename = get_path_for_type(tltype) out_dir = os.path.dirname(filename) - os.makedirs(out_dir, exist_ok=True) + if out_dir: + os.makedirs(out_dir, exist_ok=True) # Since we don't have access to the full TLObject, split the type if '.' in tltype: @@ -503,15 +518,26 @@ def generate_documentation(scheme_file): methods = sorted(methods, key=lambda m: m.name) constructors = sorted(constructors, key=lambda c: c.name) + def fmt(xs): + ys = {x: get_class_name(x) for x in xs} # cache TLObject: display + zs = {} # create a dict to hold those which have duplicated keys + for y in ys.values(): + zs[y] = y in zs + return ', '.join( + '"{}.{}"'.format(x.namespace, ys[x]) + if zs[ys[x]] and getattr(x, 'namespace', None) + else '"{}"'.format(ys[x]) for x in xs + ) + + request_names = fmt(methods) + type_names = fmt(types) + constructor_names = fmt(constructors) + def fmt(xs, formatter): return ', '.join('"{}"'.format(formatter(x)) for x in xs) - request_names = fmt(methods, get_class_name) - type_names = fmt(types, get_class_name) - constructor_names = fmt(constructors, get_class_name) - request_urls = fmt(methods, get_create_path_for) - type_urls = fmt(types, get_create_path_for) + type_urls = fmt(types, get_path_for_type) constructor_urls = fmt(constructors, get_create_path_for) replace_dict = { @@ -528,13 +554,15 @@ def generate_documentation(scheme_file): 'constructor_urls': constructor_urls } - with open('../res/core.html') as infile: - with open(original_paths['index_all'], 'w') as outfile: - text = infile.read() - for key, value in replace_dict.items(): - text = text.replace('{' + key + '}', str(value)) + shutil.copy('../res/404.html', original_paths['404']) - outfile.write(text) + with open('../res/core.html') as infile,\ + open(original_paths['index_all'], 'w') as outfile: + text = infile.read() + for key, value in replace_dict.items(): + text = text.replace('{' + key + '}', str(value)) + + outfile.write(text) # Everything done print('Documentation generated.') @@ -551,5 +579,8 @@ def copy_resources(): if __name__ == '__main__': os.makedirs('generated', exist_ok=True) os.chdir('generated') - generate_documentation('../../telethon_generator/scheme.tl') - copy_resources() + try: + generate_documentation('../../telethon_generator/scheme.tl') + copy_resources() + finally: + os.chdir(os.pardir) diff --git a/docs/res/404.html b/docs/res/404.html new file mode 100644 index 00000000..8eb3d37d --- /dev/null +++ b/docs/res/404.html @@ -0,0 +1,44 @@ + + + Oopsie! | Telethon + + + + + + + +
+

You seem a bit lost…

+

You seem to be lost! Don't worry, that's just Telegram's API being + itself. Shall we go back to the Main Page?

+
+ + diff --git a/setup.py b/setup.py index 679d068b..695ad1a5 100755 --- a/setup.py +++ b/setup.py @@ -12,34 +12,53 @@ Extra supported commands are: """ # To use a consistent encoding -from subprocess import run -from shutil import rmtree from codecs import open from sys import argv -from os import path +import os # Always prefer setuptools over distutils from setuptools import find_packages, setup try: from telethon import TelegramClient -except ImportError: +except Exception as e: + print('Failed to import TelegramClient due to', e) TelegramClient = None -if __name__ == '__main__': - if len(argv) >= 2 and argv[1] == 'gen_tl': - from telethon_generator.tl_generator import TLGenerator - generator = TLGenerator('telethon/tl') - if generator.tlobjects_exist(): - print('Detected previous TLObjects. Cleaning...') - generator.clean_tlobjects() +class TempWorkDir: + """Switches the working directory to be the one on which this file lives, + while within the 'with' block. + """ + def __init__(self): + self.original = None - print('Generating TLObjects...') - generator.generate_tlobjects( - 'telethon_generator/scheme.tl', import_depth=2 - ) - print('Done.') + def __enter__(self): + self.original = os.path.abspath(os.path.curdir) + os.chdir(os.path.abspath(os.path.dirname(__file__))) + return self + + def __exit__(self, *args): + os.chdir(self.original) + + +def gen_tl(): + from telethon_generator.tl_generator import TLGenerator + generator = TLGenerator('telethon/tl') + if generator.tlobjects_exist(): + print('Detected previous TLObjects. Cleaning...') + generator.clean_tlobjects() + + print('Generating TLObjects...') + generator.generate_tlobjects( + 'telethon_generator/scheme.tl', import_depth=2 + ) + print('Done.') + + +def main(): + if len(argv) >= 2 and argv[1] == 'gen_tl': + gen_tl() elif len(argv) >= 2 and argv[1] == 'clean_tl': from telethon_generator.tl_generator import TLGenerator @@ -48,6 +67,11 @@ if __name__ == '__main__': print('Done.') elif len(argv) >= 2 and argv[1] == 'pypi': + # Need python3.5 or higher, but Telethon is supposed to support 3.x + # Place it here since noone should be running ./setup.py pypi anyway + from subprocess import run + from shutil import rmtree + for x in ('build', 'dist', 'Telethon.egg-info'): rmtree(x, ignore_errors=True) run('python3 setup.py sdist', shell=True) @@ -58,20 +82,21 @@ if __name__ == '__main__': else: if not TelegramClient: - print('Run `python3', argv[0], 'gen_tl` first.') - quit() - - here = path.abspath(path.dirname(__file__)) + gen_tl() + from telethon import TelegramClient as TgClient + version = TgClient.__version__ + else: + version = TelegramClient.__version__ # Get the long description from the README file - with open(path.join(here, 'README.rst'), encoding='utf-8') as f: + with open('README.rst', encoding='utf-8') as f: long_description = f.read() setup( name='Telethon', # Versions should comply with PEP440. - version=TelegramClient.__version__, + version=version, description="Full-featured Telegram client library for Python 3", long_description=long_description, @@ -108,3 +133,8 @@ if __name__ == '__main__': ]), install_requires=['pyaes', 'rsa'] ) + + +if __name__ == '__main__': + with TempWorkDir(): # Could just use a try/finally but this is + reusable + main() diff --git a/telethon/crypto/auth_key.py b/telethon/crypto/auth_key.py index 02774d58..17a7f8ca 100644 --- a/telethon/crypto/auth_key.py +++ b/telethon/crypto/auth_key.py @@ -1,7 +1,8 @@ +import struct from hashlib import sha1 from .. import helpers as utils -from ..extensions import BinaryReader, BinaryWriter +from ..extensions import BinaryReader class AuthKey: @@ -17,10 +18,6 @@ class AuthKey: """Calculates the new nonce hash based on the current class fields' values """ - with BinaryWriter() as writer: - writer.write(new_nonce) - writer.write_byte(number) - writer.write_long(self.aux_hash, signed=False) - - new_nonce_hash = utils.calc_msg_key(writer.get_bytes()) - return new_nonce_hash + new_nonce = new_nonce.to_bytes(32, 'little', signed=True) + data = new_nonce + struct.pack('> 8) % 256])) - self.write(bytes([(len(data) >> 16) % 256])) - self.write(data) - - self.write(bytes(padding)) - - def tgwrite_string(self, string): - """Write a string by using Telegram guidelines""" - self.tgwrite_bytes(string.encode('utf-8')) - - def tgwrite_bool(self, boolean): - """Write a boolean value by using Telegram guidelines""" - # boolTrue boolFalse - self.write_int(0x997275b5 if boolean else 0xbc799737, signed=False) - - def tgwrite_date(self, datetime): - """Converts a Python datetime object into Unix time - (used by Telegram) and writes it - """ - value = 0 if datetime is None else int(datetime.timestamp()) - self.write_int(value) - - def tgwrite_object(self, tlobject): - """Writes a Telegram object""" - tlobject.on_send(self) - - def tgwrite_vector(self, vector): - """Writes a vector of Telegram objects""" - self.write_int(0x1cb5c415, signed=False) # Vector's constructor ID - self.write_int(len(vector)) - for item in vector: - self.tgwrite_object(item) - - # endregion - - def flush(self): - """Flush the current stream to "update" changes""" - self.writer.flush() - - def close(self): - """Close the current stream""" - self.writer.close() - - def get_bytes(self, flush=True): - """Get the current bytes array content from the buffer, - optionally flushing first - """ - if flush: - self.writer.flush() - return self.writer.raw.getvalue() - - def get_written_bytes_count(self): - """Gets the count of bytes written in the buffer. - This may NOT be equal to the stream length if one - was provided when initializing the writer - """ - return self.written_count - - # with block - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 8453af5e..6feb9841 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -8,41 +8,55 @@ from threading import Lock class TcpClient: def __init__(self, proxy=None, timeout=timedelta(seconds=5)): - self._proxy = proxy + self.proxy = proxy self._socket = None self._closing_lock = Lock() if isinstance(timeout, timedelta): - self._timeout = timeout.seconds + self.timeout = timeout.seconds elif isinstance(timeout, int) or isinstance(timeout, float): - self._timeout = float(timeout) + self.timeout = float(timeout) else: raise ValueError('Invalid timeout type', type(timeout)) def _recreate_socket(self, mode): - if self._proxy is None: + if self.proxy is None: self._socket = socket.socket(mode, socket.SOCK_STREAM) else: import socks self._socket = socks.socksocket(mode, socket.SOCK_STREAM) - if type(self._proxy) is dict: - self._socket.set_proxy(**self._proxy) + if type(self.proxy) is dict: + self._socket.set_proxy(**self.proxy) else: # tuple, list, etc. - self._socket.set_proxy(*self._proxy) + self._socket.set_proxy(*self.proxy) + + self._socket.settimeout(self.timeout) def connect(self, ip, port): """Connects to the specified IP and port number. 'timeout' must be given in seconds """ - if not self.connected: - if ':' in ip: # IPv6 - mode, address = socket.AF_INET6, (ip, port, 0, 0) - else: - mode, address = socket.AF_INET, (ip, port) + if ':' in ip: # IPv6 + mode, address = socket.AF_INET6, (ip, port, 0, 0) + else: + mode, address = socket.AF_INET, (ip, port) - self._recreate_socket(mode) - self._socket.settimeout(self._timeout) - self._socket.connect(address) + while True: + try: + while not self._socket: + self._recreate_socket(mode) + + self._socket.connect(address) + break # Successful connection, stop retrying to connect + except OSError as e: + # There are some errors that we know how to handle, and + # the loop will allow us to retry + if e.errno == errno.EBADF: + # Bad file descriptor, i.e. socket was closed, set it + # to none to recreate it on the next iteration + self._socket = None + else: + raise def _get_connected(self): return self._socket is not None @@ -67,6 +81,8 @@ class TcpClient: def write(self, data): """Writes (sends) the specified bytes to the connected peer""" + if self._socket is None: + raise ConnectionResetError() # TODO Timeout may be an issue when sending the data, Changed in v3.5: # The socket timeout is now the maximum total duration to send all data. @@ -74,13 +90,13 @@ class TcpClient: self._socket.sendall(data) except socket.timeout as e: raise TimeoutError() from e + except BrokenPipeError: + self._raise_connection_reset() except OSError as e: if e.errno == errno.EBADF: self._raise_connection_reset() else: raise - except BrokenPipeError: - self._raise_connection_reset() def read(self, size): """Reads (receives) a whole block of 'size bytes @@ -91,6 +107,9 @@ class TcpClient: and it's waiting for more, the timeout will NOT cancel the operation. Set to None for no timeout """ + if self._socket is None: + raise ConnectionResetError() + # TODO Remove the timeout from this method, always use previous one with BufferedWriter(BytesIO(), buffer_size=size) as buffer: bytes_left = size @@ -100,7 +119,7 @@ class TcpClient: except socket.timeout as e: raise TimeoutError() from e except OSError as e: - if e.errno == errno.EBADF: + if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK: self._raise_connection_reset() else: raise diff --git a/telethon/helpers.py b/telethon/helpers.py index 3212bf4b..3c9af2cb 100644 --- a/telethon/helpers.py +++ b/telethon/helpers.py @@ -47,9 +47,11 @@ def calc_msg_key(data): def generate_key_data_from_nonce(server_nonce, new_nonce): """Generates the key data corresponding to the given nonce""" - hash1 = sha1(bytes(new_nonce + server_nonce)).digest() - hash2 = sha1(bytes(server_nonce + new_nonce)).digest() - hash3 = sha1(bytes(new_nonce + new_nonce)).digest() + server_nonce = server_nonce.to_bytes(16, 'little', signed=True) + new_nonce = new_nonce.to_bytes(32, 'little', signed=True) + hash1 = sha1(new_nonce + server_nonce).digest() + hash2 = sha1(server_nonce + new_nonce).digest() + hash3 = sha1(new_nonce + new_nonce).digest() key = hash1 + hash2[:12] iv = hash2[12:20] + hash3 + new_nonce[:4] diff --git a/telethon/network/authenticator.py b/telethon/network/authenticator.py index 7d600fb3..1081897a 100644 --- a/telethon/network/authenticator.py +++ b/telethon/network/authenticator.py @@ -2,12 +2,19 @@ import os import time from hashlib import sha1 +from ..tl.types import ( + ResPQ, PQInnerData, ServerDHParamsFail, ServerDHParamsOk, + ServerDHInnerData, ClientDHInnerData, DhGenOk, DhGenRetry, DhGenFail +) from .. import helpers as utils from ..crypto import AES, AuthKey, Factorization from ..crypto import rsa -from ..errors import SecurityError, TypeNotFoundError -from ..extensions import BinaryReader, BinaryWriter +from ..errors import SecurityError +from ..extensions import BinaryReader from ..network import MtProtoPlainSender +from ..tl.functions import ( + ReqPqRequest, ReqDHParamsRequest, SetClientDHParamsRequest +) def do_authentication(connection, retries=5): @@ -18,7 +25,7 @@ def do_authentication(connection, retries=5): while retries: try: return _do_authentication(connection) - except (SecurityError, TypeNotFoundError, NotImplementedError) as e: + except (SecurityError, AssertionError, NotImplementedError) as e: last_error = e retries -= 1 raise last_error @@ -30,202 +37,158 @@ def _do_authentication(connection): time offset. """ sender = MtProtoPlainSender(connection) - sender.connect() - # Step 1 sending: PQ Request - nonce = os.urandom(16) - with BinaryWriter(known_length=20) as writer: - writer.write_int(0x60469778, signed=False) # Constructor number - writer.write(nonce) - sender.send(writer.get_bytes()) - - # Step 1 response: PQ Request - pq, pq_bytes, server_nonce, fingerprints = None, None, None, [] + # Step 1 sending: PQ Request, endianness doesn't matter since it's random + req_pq_request = ReqPqRequest( + nonce=int.from_bytes(os.urandom(16), 'big', signed=True) + ) + sender.send(req_pq_request.to_bytes()) with BinaryReader(sender.receive()) as reader: - response_code = reader.read_int(signed=False) - if response_code != 0x05162463: - raise TypeNotFoundError(response_code) + req_pq_request.on_response(reader) - nonce_from_server = reader.read(16) - if nonce_from_server != nonce: - raise SecurityError('Invalid nonce from server') + res_pq = req_pq_request.result + if not isinstance(res_pq, ResPQ): + raise AssertionError(res_pq) - server_nonce = reader.read(16) + if res_pq.nonce != req_pq_request.nonce: + raise SecurityError('Invalid nonce from server') - pq_bytes = reader.tgread_bytes() - pq = get_int(pq_bytes) - - vector_id = reader.read_int() - if vector_id != 0x1cb5c415: - raise TypeNotFoundError(response_code) - - fingerprints = [] - fingerprint_count = reader.read_int() - for _ in range(fingerprint_count): - fingerprints.append(reader.read(8)) + pq = get_int(res_pq.pq) # Step 2 sending: DH Exchange - new_nonce = os.urandom(32) p, q = Factorization.factorize(pq) - with BinaryWriter() as pq_inner_data_writer: - pq_inner_data_writer.write_int( - 0x83c95aec, signed=False) # PQ Inner Data - pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(pq)) - pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(min(p, q))) - pq_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(max(p, q))) - pq_inner_data_writer.write(nonce) - pq_inner_data_writer.write(server_nonce) - pq_inner_data_writer.write(new_nonce) + p, q = rsa.get_byte_array(min(p, q)), rsa.get_byte_array(max(p, q)) + new_nonce = int.from_bytes(os.urandom(32), 'little', signed=True) - # sha_digest + data + random_bytes - cipher_text, target_fingerprint = None, None - for fingerprint in fingerprints: - cipher_text = rsa.encrypt( - fingerprint, - pq_inner_data_writer.get_bytes() + pq_inner_data = PQInnerData( + pq=rsa.get_byte_array(pq), p=p, q=q, + nonce=res_pq.nonce, + server_nonce=res_pq.server_nonce, + new_nonce=new_nonce + ).to_bytes() + + # sha_digest + data + random_bytes + cipher_text, target_fingerprint = None, None + for fingerprint in res_pq.server_public_key_fingerprints: + cipher_text = rsa.encrypt(fingerprint, pq_inner_data) + if cipher_text is not None: + target_fingerprint = fingerprint + break + + if cipher_text is None: + raise SecurityError( + 'Could not find a valid key for fingerprints: {}' + .format(', '.join( + [str(f) for f in res_pq.server_public_key_fingerprints]) ) + ) - if cipher_text is not None: - target_fingerprint = fingerprint - break - - if cipher_text is None: - raise SecurityError( - 'Could not find a valid key for fingerprints: {}' - .format(', '.join([repr(f) for f in fingerprints])) - ) - - with BinaryWriter() as req_dh_params_writer: - req_dh_params_writer.write_int( - 0xd712e4be, signed=False) # Req DH Params - req_dh_params_writer.write(nonce) - req_dh_params_writer.write(server_nonce) - req_dh_params_writer.tgwrite_bytes(rsa.get_byte_array(min(p, q))) - req_dh_params_writer.tgwrite_bytes(rsa.get_byte_array(max(p, q))) - req_dh_params_writer.write(target_fingerprint) - req_dh_params_writer.tgwrite_bytes(cipher_text) - - req_dh_params_bytes = req_dh_params_writer.get_bytes() - sender.send(req_dh_params_bytes) + req_dh_params = ReqDHParamsRequest( + nonce=res_pq.nonce, + server_nonce=res_pq.server_nonce, + p=p, q=q, + public_key_fingerprint=target_fingerprint, + encrypted_data=cipher_text + ) + sender.send(req_dh_params.to_bytes()) # Step 2 response: DH Exchange - encrypted_answer = None with BinaryReader(sender.receive()) as reader: - response_code = reader.read_int(signed=False) + req_dh_params.on_response(reader) - if response_code == 0x79cb045d: - raise SecurityError('Server DH params fail: TODO') + server_dh_params = req_dh_params.result + if isinstance(server_dh_params, ServerDHParamsFail): + raise SecurityError('Server DH params fail: TODO') - if response_code != 0xd0e8075c: - raise TypeNotFoundError(response_code) + if not isinstance(server_dh_params, ServerDHParamsOk): + raise AssertionError(server_dh_params) - nonce_from_server = reader.read(16) - if nonce_from_server != nonce: - raise SecurityError('Invalid nonce from server') + if server_dh_params.nonce != res_pq.nonce: + raise SecurityError('Invalid nonce from server') - server_nonce_from_server = reader.read(16) - if server_nonce_from_server != server_nonce: - raise SecurityError('Invalid server nonce from server') - - encrypted_answer = reader.tgread_bytes() + if server_dh_params.server_nonce != res_pq.server_nonce: + raise SecurityError('Invalid server nonce from server') # Step 3 sending: Complete DH Exchange - key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce) - plain_text_answer = AES.decrypt_ige(encrypted_answer, key, iv) + key, iv = utils.generate_key_data_from_nonce( + res_pq.server_nonce, new_nonce + ) + plain_text_answer = AES.decrypt_ige( + server_dh_params.encrypted_answer, key, iv + ) - g, dh_prime, ga, time_offset = None, None, None, None - with BinaryReader(plain_text_answer) as dh_inner_data_reader: - dh_inner_data_reader.read(20) # hash sum - code = dh_inner_data_reader.read_int(signed=False) - if code != 0xb5890dba: - raise TypeNotFoundError(code) + with BinaryReader(plain_text_answer) as reader: + reader.read(20) # hash sum + server_dh_inner = reader.tgread_object() + if not isinstance(server_dh_inner, ServerDHInnerData): + raise AssertionError(server_dh_inner) - nonce_from_server1 = dh_inner_data_reader.read(16) - if nonce_from_server1 != nonce: - raise SecurityError('Invalid nonce in encrypted answer') + if server_dh_inner.nonce != res_pq.nonce: + print(server_dh_inner.nonce, res_pq.nonce) + raise SecurityError('Invalid nonce in encrypted answer') - server_nonce_from_server1 = dh_inner_data_reader.read(16) - if server_nonce_from_server1 != server_nonce: - raise SecurityError('Invalid server nonce in encrypted answer') + if server_dh_inner.server_nonce != res_pq.server_nonce: + raise SecurityError('Invalid server nonce in encrypted answer') - g = dh_inner_data_reader.read_int() - dh_prime = get_int(dh_inner_data_reader.tgread_bytes(), signed=False) - ga = get_int(dh_inner_data_reader.tgread_bytes(), signed=False) - - server_time = dh_inner_data_reader.read_int() - time_offset = server_time - int(time.time()) + dh_prime = get_int(server_dh_inner.dh_prime, signed=False) + g_a = get_int(server_dh_inner.g_a, signed=False) + time_offset = server_dh_inner.server_time - int(time.time()) b = get_int(os.urandom(256), signed=False) - gb = pow(g, b, dh_prime) - gab = pow(ga, b, dh_prime) + gb = pow(server_dh_inner.g, b, dh_prime) + gab = pow(g_a, b, dh_prime) # Prepare client DH Inner Data - with BinaryWriter() as client_dh_inner_data_writer: - client_dh_inner_data_writer.write_int( - 0x6643b654, signed=False) # Client DH Inner Data - client_dh_inner_data_writer.write(nonce) - client_dh_inner_data_writer.write(server_nonce) - client_dh_inner_data_writer.write_long(0) # TODO retry_id - client_dh_inner_data_writer.tgwrite_bytes(rsa.get_byte_array(gb)) + client_dh_inner = ClientDHInnerData( + nonce=res_pq.nonce, + server_nonce=res_pq.server_nonce, + retry_id=0, # TODO Actual retry ID + g_b=rsa.get_byte_array(gb) + ).to_bytes() - with BinaryWriter() as client_dh_inner_data_with_hash_writer: - client_dh_inner_data_with_hash_writer.write( - sha1(client_dh_inner_data_writer.get_bytes()).digest()) - - client_dh_inner_data_with_hash_writer.write( - client_dh_inner_data_writer.get_bytes()) - - client_dh_inner_data_bytes = \ - client_dh_inner_data_with_hash_writer.get_bytes() + client_dh_inner_hashed = sha1(client_dh_inner).digest() + client_dh_inner # Encryption - client_dh_inner_data_encrypted_bytes = AES.encrypt_ige( - client_dh_inner_data_bytes, key, iv) + client_dh_encrypted = AES.encrypt_ige(client_dh_inner_hashed, key, iv) # Prepare Set client DH params - with BinaryWriter() as set_client_dh_params_writer: - set_client_dh_params_writer.write_int(0xf5045f1f, signed=False) - set_client_dh_params_writer.write(nonce) - set_client_dh_params_writer.write(server_nonce) - set_client_dh_params_writer.tgwrite_bytes( - client_dh_inner_data_encrypted_bytes) - - set_client_dh_params_bytes = set_client_dh_params_writer.get_bytes() - sender.send(set_client_dh_params_bytes) + set_client_dh = SetClientDHParamsRequest( + nonce=res_pq.nonce, + server_nonce=res_pq.server_nonce, + encrypted_data=client_dh_encrypted, + ) + sender.send(set_client_dh.to_bytes()) # Step 3 response: Complete DH Exchange with BinaryReader(sender.receive()) as reader: - # Everything read from the server, disconnect now - sender.disconnect() + set_client_dh.on_response(reader) - code = reader.read_int(signed=False) - if code == 0x3bcbf734: # DH Gen OK - nonce_from_server = reader.read(16) - if nonce_from_server != nonce: - raise SecurityError('Invalid nonce from server') + dh_gen = set_client_dh.result + if isinstance(dh_gen, DhGenOk): + if dh_gen.nonce != res_pq.nonce: + raise SecurityError('Invalid nonce from server') - server_nonce_from_server = reader.read(16) - if server_nonce_from_server != server_nonce: - raise SecurityError('Invalid server nonce from server') + if dh_gen.server_nonce != res_pq.server_nonce: + raise SecurityError('Invalid server nonce from server') - new_nonce_hash1 = reader.read(16) - auth_key = AuthKey(rsa.get_byte_array(gab)) + auth_key = AuthKey(rsa.get_byte_array(gab)) + new_nonce_hash = int.from_bytes( + auth_key.calc_new_nonce_hash(new_nonce, 1), 'little', signed=True + ) - new_nonce_hash_calculated = auth_key.calc_new_nonce_hash(new_nonce, - 1) - if new_nonce_hash1 != new_nonce_hash_calculated: - raise SecurityError('Invalid new nonce hash') + if dh_gen.new_nonce_hash1 != new_nonce_hash: + raise SecurityError('Invalid new nonce hash') - return auth_key, time_offset + return auth_key, time_offset - elif code == 0x46dc1fb9: # DH Gen Retry - raise NotImplementedError('dh_gen_retry') + elif isinstance(dh_gen, DhGenRetry): + raise NotImplementedError('DhGenRetry') - elif code == 0xa69dae02: # DH Gen Fail - raise NotImplementedError('dh_gen_fail') + elif isinstance(dh_gen, DhGenFail): + raise NotImplementedError('DhGenFail') - else: - raise NotImplementedError('DH Gen unknown: {}'.format(hex(code))) + else: + raise NotImplementedError('DH Gen unknown: {}'.format(dh_gen)) def get_int(byte_array, signed=True): diff --git a/telethon/network/connection.py b/telethon/network/connection.py index 1426ce78..28c548eb 100644 --- a/telethon/network/connection.py +++ b/telethon/network/connection.py @@ -1,10 +1,13 @@ import os +import struct from datetime import timedelta from zlib import crc32 from enum import Enum +import errno + from ..crypto import AESModeCTR -from ..extensions import BinaryWriter, TcpClient +from ..extensions import TcpClient from ..errors import InvalidChecksumError @@ -75,9 +78,15 @@ class Connection: setattr(self, 'read', self._read_plain) def connect(self): - self._send_counter = 0 - self.conn.connect(self.ip, self.port) + try: + self.conn.connect(self.ip, self.port) + except OSError as e: + if e.errno == errno.EISCONN: + return # Already connected, no need to re-set everything up + else: + raise + self._send_counter = 0 if self._mode == ConnectionMode.TCP_ABRIDGED: self.conn.write(b'\xef') elif self._mode == ConnectionMode.TCP_INTERMEDIATE: @@ -85,6 +94,9 @@ class Connection: elif self._mode == ConnectionMode.TCP_OBFUSCATED: self._setup_obfuscation() + def get_timeout(self): + return self.conn.timeout + def _setup_obfuscation(self): # Obfuscated messages secrets cannot start with any of these keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4) @@ -118,6 +130,13 @@ class Connection: def close(self): self.conn.close() + def clone(self): + """Creates a copy of this Connection""" + return Connection(self.ip, self.port, + mode=self._mode, + proxy=self.conn.proxy, + timeout=self.conn.timeout) + # region Receive message implementations def recv(self): @@ -164,30 +183,22 @@ class Connection: # https://core.telegram.org/mtproto#tcp-transport # total length, sequence number, packet and checksum (CRC32) length = len(message) + 12 - with BinaryWriter(known_length=length) as writer: - writer.write_int(length) - writer.write_int(self._send_counter) - writer.write(message) - writer.write_int(crc32(writer.get_bytes()), signed=False) - self._send_counter += 1 - self.write(writer.get_bytes()) + data = struct.pack('> 2 - if length < 127: - writer.write_byte(length) - else: - writer.write_byte(127) - writer.write(int.to_bytes(length, 3, 'little')) - writer.write(message) - self.write(writer.get_bytes()) + length = len(message) >> 2 + if length < 127: + length = struct.pack('B', length) + else: + length = b'\x7f' + int.to_bytes(length, 3, 'little') + + self.write(length + message) # endregion diff --git a/telethon/network/mtproto_plain_sender.py b/telethon/network/mtproto_plain_sender.py index 5ced50a9..c7c021be 100644 --- a/telethon/network/mtproto_plain_sender.py +++ b/telethon/network/mtproto_plain_sender.py @@ -1,7 +1,8 @@ +import struct import time from ..errors import BrokenAuthKeyError -from ..extensions import BinaryReader, BinaryWriter +from ..extensions import BinaryReader class MtProtoPlainSender: @@ -25,14 +26,9 @@ class MtProtoPlainSender: """Sends a plain packet (auth_key_id = 0) containing the given message body (data) """ - with BinaryWriter(known_length=len(data) + 20) as writer: - writer.write_long(0) - writer.write_long(self._get_new_msg_id()) - writer.write_int(len(data)) - writer.write(data) - - packet = writer.get_bytes() - self._connection.send(packet) + self._connection.send( + struct.pack(' self._last_ping + self._ping_delay: + self._sender.send(PingRequest( + int.from_bytes(os.urandom(8), 'big', signed=True) + )) + self._last_ping = datetime.now() + + self._sender.receive(update_state=self.updates) + except TimeoutError: + # No problem. + pass + except ConnectionResetError: + self._logger.debug('Server disconnected us. Reconnecting...') + while self._user_connected and not self._reconnect(): + sleep(0.1) # Retry forever, this is instant messaging + + except Exception as error: + # Unknown exception, pass it to the main thread + self._logger.debug( + '[ERROR] Unknown error on the read thread, please report', + error + ) + # If something strange happens we don't want to enter an + # infinite loop where all we do is raise an exception, so + # add a little sleep to avoid the CPU usage going mad. + sleep(0.1) + break + + self._recv_thread = None + + # endregion diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index 5d3f5ae2..85346b42 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -1,20 +1,21 @@ import os -import threading from datetime import datetime, timedelta from functools import lru_cache from mimetypes import guess_type -from threading import Thread + +try: + import socks +except ImportError: + socks = None from . import TelegramBareClient from . import helpers as utils from .errors import ( RPCError, UnauthorizedError, InvalidParameterError, PhoneCodeEmptyError, - PhoneMigrateError, NetworkMigrateError, UserMigrateError, PhoneCodeExpiredError, PhoneCodeHashEmptyError, PhoneCodeInvalidError ) from .network import ConnectionMode -from .tl import Session, TLObject -from .tl.functions import PingRequest +from .tl import TLObject from .tl.functions.account import ( GetPasswordRequest ) @@ -29,9 +30,6 @@ from .tl.functions.messages import ( GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest, SendMessageRequest ) -from .tl.functions.updates import ( - GetStateRequest -) from .tl.functions.users import ( GetUsersRequest ) @@ -43,6 +41,7 @@ from .tl.types import ( InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID, UpdateNewMessage, UpdateShortSentMessage ) +from .tl.types.messages import DialogsSlice from .utils import find_user_or_chat, get_extension @@ -59,8 +58,9 @@ class TelegramClient(TelegramBareClient): def __init__(self, session, api_id, api_hash, connection_mode=ConnectionMode.TCP_FULL, proxy=None, - process_updates=False, + update_workers=None, timeout=timedelta(seconds=5), + spawn_read_thread=True, **kwargs): """Initializes the Telegram client with the specified API ID and Hash. @@ -73,15 +73,21 @@ class TelegramClient(TelegramBareClient): This will only affect how messages are sent over the network and how much processing is required before sending them. - If 'process_updates' is set to True, incoming updates will be - processed and you must manually call 'self.updates.poll()' from - another thread to retrieve the saved update objects, or your - memory will fill with these. You may modify the value of - 'self.updates.polling' at any later point. + The integer 'update_workers' represents depending on its value: + is None: Updates will *not* be stored in memory. + = 0: Another thread is responsible for calling self.updates.poll() + > 0: 'update_workers' background threads will be spawned, any + any of them will invoke all the self.updates.handlers. - Despite the value of 'process_updates', if you later call - '.add_update_handler(...)', updates will also be processed - and the update objects will be passed to the handlers you added. + If 'spawn_read_thread', a background thread will be started once + an authorized user has been logged in to Telegram to read items + (such as updates and responses) from the network as soon as they + occur, which will speed things up. + + If you don't want to spawn any additional threads, pending updates + will be read and processed accordingly after invoking a request + and not immediately. This is useful if you don't care about updates + at all and have set 'update_workers=None'. If more named arguments are provided as **kwargs, they will be used to update the Session instance. Most common settings are: @@ -92,210 +98,55 @@ class TelegramClient(TelegramBareClient): system_lang_code = lang_code report_errors = True """ - if not api_id or not api_hash: - raise PermissionError( - "Your API ID or Hash cannot be empty or None. " - "Refer to Telethon's README.rst for more information.") - - # Determine what session object we have - if isinstance(session, str) or session is None: - session = Session.try_load_or_create_new(session) - elif not isinstance(session, Session): - raise ValueError( - 'The given session must be a str or a Session instance.') - super().__init__( session, api_id, api_hash, connection_mode=connection_mode, proxy=proxy, - process_updates=process_updates, - timeout=timeout + update_workers=update_workers, + spawn_read_thread=spawn_read_thread, + timeout=timeout, + **kwargs ) - # Used on connection - the user may modify these and reconnect - kwargs['app_version'] = kwargs.get('app_version', self.__version__) - for name, value in kwargs.items(): - if hasattr(self.session, name): - setattr(self.session, name, value) - - self._updates_thread = None + # Some fields to easy signing in self._phone_code_hash = None self._phone = None - # Uploaded files cache so subsequent calls are instant - self._upload_cache = {} - - # Constantly read for results and updates from within the main client - self._recv_thread = None - - # Default PingRequest delay - self._last_ping = datetime.now() - self._ping_delay = timedelta(minutes=1) - - # endregion - - # region Connecting - - def connect(self, exported_auth=None): - """Connects to the Telegram servers, executing authentication if - required. Note that authenticating to the Telegram servers is - not the same as authenticating the desired user itself, which - may require a call (or several) to 'sign_in' for the first time. - - exported_auth is meant for internal purposes and can be ignored. - """ - if self._sender and self._sender.is_connected(): - return - - ok = super().connect(exported_auth=exported_auth) - # The main TelegramClient is the only one that will have - # constant_read, since it's also the only one who receives - # updates and need to be processed as soon as they occur. - # - # TODO Allow to disable this to avoid the creation of a new thread - # if the user is not going to work with updates at all? Whether to - # read constantly or not for updates needs to be known before hand, - # and further updates won't be able to be added unless allowing to - # switch the mode on the fly. - if ok: - self._recv_thread = Thread( - name='ReadThread', daemon=True, - target=self._recv_thread_impl - ) - self._recv_thread.start() - if self.updates.polling: - self.sync_updates() - - return ok - - def disconnect(self): - """Disconnects from the Telegram server - and stops all the spawned threads""" - if not self._sender or not self._sender.is_connected(): - return - - # The existing thread will close eventually, since it's - # only running while the MtProtoSender.is_connected() - self._recv_thread = None - - # This will trigger a "ConnectionResetError", usually, the background - # thread would try restarting the connection but since the - # ._recv_thread = None, it knows it doesn't have to. - super().disconnect() - - # Also disconnect all the cached senders - for sender in self._cached_clients.values(): - sender.disconnect() - - self._cached_clients.clear() - - # endregion - - # region Working with different connections - - def create_new_connection(self, on_dc=None, timeout=timedelta(seconds=5)): - """Creates a new connection which can be used in parallel - with the original TelegramClient. A TelegramBareClient - will be returned already connected, and the caller is - responsible to disconnect it. - - If 'on_dc' is None, the new client will run on the same - data center as the current client (most common case). - - If the client is meant to be used on a different data - center, the data center ID should be specified instead. - """ - if on_dc is None: - client = TelegramBareClient( - self.session, self.api_id, self.api_hash, - proxy=self.proxy, timeout=timeout - ) - client.connect() - else: - client = self._get_exported_client(on_dc, bypass_cache=True) - - return client - # endregion # region Telegram requests functions - def invoke(self, request, *args, **kwargs): - """Invokes (sends) a MTProtoRequest and returns (receives) its result. - An optional 'retries' parameter can be set. - - *args will be ignored. - """ - if self._recv_thread is not None and \ - threading.get_ident() == self._recv_thread.ident: - raise AssertionError('Cannot invoke requests from the ReadThread') - - self.updates.check_error() - - try: - # Users may call this method from within some update handler. - # If this is the case, then the thread invoking the request - # will be the one which should be reading (but is invoking the - # request) thus not being available to read it "in the background" - # and it's needed to call receive. - return super().invoke( - request, call_receive=self._recv_thread is None, - retries=kwargs.get('retries', 5) - ) - - except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: - self._logger.debug('DC error when invoking request, ' - 'attempting to reconnect at DC {}' - .format(e.new_dc)) - - self.reconnect(new_dc=e.new_dc) - return self.invoke(request) - - # Let people use client(SomeRequest()) instead client.invoke(...) - __call__ = invoke - - def invoke_on_dc(self, request, dc_id, reconnect=False): - """Invokes the given request on a different DC - by making use of the exported MtProtoSenders. - - If 'reconnect=True', then the a reconnection will be performed and - ConnectionResetError will be raised if it occurs a second time. - """ - try: - client = self._get_exported_client( - dc_id, init_connection=reconnect) - - return client.invoke(request) - - except ConnectionResetError: - if reconnect: - raise - else: - return self.invoke_on_dc(request, dc_id, reconnect=True) - # region Authorization requests - def is_user_authorized(self): - """Has the user been authorized yet - (code request sent and confirmed)?""" - return self.session and self.get_me() is not None - def send_code_request(self, phone): """Sends a code request to the specified phone number""" - result = self( - SendCodeRequest(phone, self.api_id, self.api_hash)) + if isinstance(phone, int): + phone = str(phone) + elif phone.startswith('+'): + phone = phone.strip('+') + + result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) self._phone = phone self._phone_code_hash = result.phone_code_hash return result def sign_in(self, phone=None, code=None, - password=None, bot_token=None): + password=None, bot_token=None, phone_code_hash=None): """Completes the sign in process with the phone number + code pair. If no phone or code is provided, then the sole password will be used. The password should be used after a normal authorization attempt has happened and an SessionPasswordNeededError was raised. + If you're calling .sign_in() on two completely different clients + (for example, through an API that creates a new client per phone), + you must first call .sign_in(phone) to receive the code, and then + with the result such method results, call + .sign_in(phone, code, phone_code_hash=result.phone_code_hash). + + If this is done on the same client, the client will fill said values + for you. + To login as a bot, only `bot_token` should be provided. This should equal to the bot access hash provided by https://t.me/BotFather during your bot creation. @@ -306,64 +157,66 @@ class TelegramClient(TelegramBareClient): if phone and not code: return self.send_code_request(phone) elif code: - if self._phone is None: + phone = phone or self._phone + phone_code_hash = phone_code_hash or self._phone_code_hash + if not phone: raise ValueError( - 'Please make sure to call send_code_request first.') + 'Please make sure to call send_code_request first.' + ) + if not phone_code_hash: + raise ValueError('You also need to provide a phone_code_hash.') try: if isinstance(code, int): code = str(code) result = self(SignInRequest( - self._phone, self._phone_code_hash, code + phone, phone_code_hash, code )) except (PhoneCodeEmptyError, PhoneCodeExpiredError, PhoneCodeHashEmptyError, PhoneCodeInvalidError): return None - elif password: salt = self(GetPasswordRequest()).current_salt - result = self( - CheckPasswordRequest(utils.get_password_hash(password, salt))) - + result = self(CheckPasswordRequest( + utils.get_password_hash(password, salt) + )) elif bot_token: result = self(ImportBotAuthorizationRequest( flags=0, bot_auth_token=bot_token, - api_id=self.api_id, api_hash=self.api_hash)) - + api_id=self.api_id, api_hash=self.api_hash + )) else: raise ValueError( 'You must provide a phone and a code the first time, ' - 'and a password only if an RPCError was raised before.') + 'and a password only if an RPCError was raised before.' + ) + self._set_connected_and_authorized() return result.user def sign_up(self, code, first_name, last_name=''): """Signs up to Telegram. Make sure you sent a code request first!""" - return self(SignUpRequest( + result = self(SignUpRequest( phone_number=self._phone, phone_code_hash=self._phone_code_hash, phone_code=code, first_name=first_name, last_name=last_name - )).user + )) + + self._set_connected_and_authorized() + return result.user def log_out(self): """Logs out and deletes the current session. Returns True if everything went okay.""" - # Special flag when logging out (so the ack request confirms it) - self._sender.logging_out = True - try: self(LogOutRequest()) - # The server may have already disconnected us, we still - # try to disconnect to make sure. - self.disconnect() - except (RPCError, ConnectionError): - # Something happened when logging out, restore the state back - self._sender.logging_out = False + except RPCError: return False + self.disconnect() self.session.delete() self.session = None return True @@ -386,22 +239,61 @@ class TelegramClient(TelegramBareClient): offset_id=0, offset_peer=InputPeerEmpty()): """Returns a tuple of lists ([dialogs], [entities]) - with at least 'limit' items each. + with at least 'limit' items each unless all dialogs were consumed. + + If `limit` is None, all dialogs will be retrieved (from the given + offset) will be retrieved. - If `limit` is 0, all dialogs will (should) retrieved. The `entities` represent the user, chat or channel - corresponding to that dialog. + corresponding to that dialog. If it's an integer, not + all dialogs may be retrieved at once. """ + if limit is None: + limit = float('inf') - r = self( - GetDialogsRequest( + dialogs = {} # Use Dialog.top_message as identifier to avoid dupes + messages = {} # Used later for sorting TODO also return these? + entities = {} + while len(dialogs) < limit: + r = self(GetDialogsRequest( offset_date=offset_date, offset_id=offset_id, offset_peer=offset_peer, - limit=limit)) + limit=0 # limit 0 often means "as much as possible" + )) + if not r.dialogs: + break + + for d in r.dialogs: + dialogs[d.top_message] = d + for m in r.messages: + messages[m.id] = m + + # We assume users can't have the same ID as a chat + for u in r.users: + entities[u.id] = u + for c in r.chats: + entities[c.id] = c + + if isinstance(r, DialogsSlice): + # Don't enter next iteration if we already got all + break + + offset_date = r.messages[-1].date + offset_peer = find_user_or_chat(r.dialogs[-1].peer, entities, + entities) + offset_id = r.messages[-1].id & 4294967296 # Telegram/danog magic + + # Sort by message date + no_date = datetime.fromtimestamp(0) + dialogs = sorted( + list(dialogs.values()), + key=lambda d: getattr(messages[d.top_message], 'date', no_date) + ) return ( - r.dialogs, - [find_user_or_chat(d.peer, r.users, r.chats) for d in r.dialogs]) + dialogs, + [find_user_or_chat(d.peer, entities, entities) for d in dialogs] + ) # endregion @@ -427,7 +319,7 @@ class TelegramClient(TelegramBareClient): reply_to_msg_id=self._get_reply_to(reply_to) ) result = self(request) - if isinstance(request, UpdateShortSentMessage): + if isinstance(result, UpdateShortSentMessage): return Message( id=result.id, to_id=entity, @@ -540,7 +432,7 @@ class TelegramClient(TelegramBareClient): return reply_to if isinstance(reply_to, TLObject) and \ - type(reply_to).subclass_of_id == 0x790009e3: + type(reply_to).SUBCLASS_OF_ID == 0x790009e3: # hex(crc32(b'Message')) = 0x790009e3 return reply_to.id @@ -972,77 +864,3 @@ class TelegramClient(TelegramBareClient): ) # endregion - - # region Updates handling - - def sync_updates(self): - """Synchronizes self.updates to their initial state. Will be - called automatically on connection if self.updates.enabled = True, - otherwise it should be called manually after enabling updates. - """ - try: - self.updates.process(self(GetStateRequest())) - return True - except UnauthorizedError: - return False - - def add_update_handler(self, handler): - """Adds an update handler (a function which takes a TLObject, - an update, as its parameter) and listens for updates""" - sync = not self.updates.handlers - self.updates.handlers.append(handler) - if sync: - self.sync_updates() - - def remove_update_handler(self, handler): - self.updates.handlers.remove(handler) - - def list_update_handlers(self): - return self.updates.handlers[:] - - # endregion - - # Constant read - - # By using this approach, another thread will be - # created and started upon connection to constantly read - # from the other end. Otherwise, manual calls to .receive() - # must be performed. The MtProtoSender cannot be connected, - # or an error will be thrown. - # - # This way, sending and receiving will be completely independent. - def _recv_thread_impl(self): - while self._sender and self._sender.is_connected(): - try: - if datetime.now() > self._last_ping + self._ping_delay: - self._sender.send(PingRequest( - int.from_bytes(os.urandom(8), 'big', signed=True) - )) - self._last_ping = datetime.now() - - self._sender.receive(update_state=self.updates) - except AttributeError: - # 'NoneType' object has no attribute 'receive'. - # The only moment when this can happen is reconnection - # was triggered from another thread and the ._sender - # was set to None, so close this thread and exit by return. - self._recv_thread = None - return - except TimeoutError: - # No problem. - pass - except ConnectionResetError: - if self._recv_thread is not None: - # Do NOT attempt reconnecting unless the connection was - # finished by the user -> ._recv_thread is None - self._logger.debug('Server disconnected us. Reconnecting...') - self._recv_thread = None # Not running anymore - self.reconnect() - return - except Exception as e: - # Unknown exception, pass it to the main thread - self.updates.set_error(e) - self._recv_thread = None - return - - # endregion diff --git a/telethon/tl/__init__.py b/telethon/tl/__init__.py index 9ee6a979..403e481a 100644 --- a/telethon/tl/__init__.py +++ b/telethon/tl/__init__.py @@ -1,2 +1,5 @@ from .tlobject import TLObject from .session import Session +from .gzip_packed import GzipPacked +from .tl_message import TLMessage +from .message_container import MessageContainer diff --git a/telethon/tl/gzip_packed.py b/telethon/tl/gzip_packed.py new file mode 100644 index 00000000..05453d4b --- /dev/null +++ b/telethon/tl/gzip_packed.py @@ -0,0 +1,38 @@ +import gzip +import struct + +from . import TLObject + + +class GzipPacked(TLObject): + CONSTRUCTOR_ID = 0x3072cfa1 + + def __init__(self, data): + super().__init__() + self.data = data + + @staticmethod + def gzip_if_smaller(request): + """Calls request.to_bytes(), and based on a certain threshold, + optionally gzips the resulting data. If the gzipped data is + smaller than the original byte array, this is returned instead. + + Note that this only applies to content related requests. + """ + data = request.to_bytes() + # TODO This threshold could be configurable + if request.content_related and len(data) > 512: + gzipped = GzipPacked(data).to_bytes() + return gzipped if len(gzipped) < len(data) else data + else: + return data + + def to_bytes(self): + # TODO Maybe compress level could be an option + return struct.pack('> 8) % 256, + (len(data) >> 16) % 256 + ])) + r.append(data) + + r.append(bytes(padding)) + return b''.join(r) + # These should be overrode - def to_dict(self): + def to_dict(self, recursive=True): return {} - def on_send(self, writer): - pass + def to_bytes(self): + return b'' def on_response(self, reader): pass diff --git a/telethon/update_state.py b/telethon/update_state.py index 2f313dea..fa1963a3 100644 --- a/telethon/update_state.py +++ b/telethon/update_state.py @@ -1,6 +1,7 @@ +import logging from collections import deque from datetime import datetime -from threading import RLock, Event +from threading import RLock, Event, Thread from .tl import types as tl @@ -9,78 +10,136 @@ class UpdateState: """Used to hold the current state of processed updates. To retrieve an update, .poll() should be called. """ - def __init__(self, polling): - self._polling = polling + WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers + + def __init__(self, workers=None): + """ + :param workers: This integer parameter has three possible cases: + workers is None: Updates will *not* be stored on self. + workers = 0: Another thread is responsible for calling self.poll() + workers > 0: 'workers' background threads will be spawned, any + any of them will invoke all the self.handlers. + """ + self._workers = workers + self._worker_threads = [] + self.handlers = [] self._updates_lock = RLock() self._updates_available = Event() self._updates = deque() + self._logger = logging.getLogger(__name__) + # https://core.telegram.org/api/updates self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) + self._setup_workers() def can_poll(self): """Returns True if a call to .poll() won't lock""" return self._updates_available.is_set() - def poll(self): - """Polls an update or blocks until an update object is available""" - if not self._polling: - raise ValueError('Updates are not being polled hence not saved.') + def poll(self, timeout=None): + """Polls an update or blocks until an update object is available. + If 'timeout is not None', it should be a floating point value, + and the method will 'return None' if waiting times out. + """ + if not self._updates_available.wait(timeout=timeout): + return - self._updates_available.wait() with self._updates_lock: + if not self._updates_available.is_set(): + return + update = self._updates.popleft() if not self._updates: self._updates_available.clear() if isinstance(update, Exception): - raise update # Some error was set through .set_error() + raise update # Some error was set through (surely StopIteration) return update - def get_polling(self): - return self._polling + def get_workers(self): + return self._workers - def set_polling(self, polling): - self._polling = polling - if not polling: - with self._updates_lock: - self._updates.clear() - - polling = property(fget=get_polling, fset=set_polling) - - def set_error(self, error): - """Sets an error, so that the next call to .poll() will raise it. - Can be (and is) used to pass exceptions between threads. + def set_workers(self, n): + """Changes the number of workers running. + If 'n is None', clears all pending updates from memory. """ - with self._updates_lock: - # Insert at the beginning so the very next poll causes an error - # TODO Should this reset the pts and such? - self._updates.insert(0, error) - self._updates_available.set() + self._stop_workers() + self._workers = n + if n is None: + self._updates.clear() + else: + self._setup_workers() - def check_error(self): - with self._updates_lock: - if self._updates and isinstance(self._updates[0], Exception): - raise self._updates.pop() + workers = property(fget=get_workers, fset=set_workers) + + def _stop_workers(self): + """Raises "StopIterationException" on the worker threads to stop them, + and also clears all of them off the list + """ + if self._workers: + with self._updates_lock: + # Insert at the beginning so the very next poll causes an error + # on all the worker threads + # TODO Should this reset the pts and such? + for _ in range(self._workers): + self._updates.appendleft(StopIteration()) + self._updates_available.set() + + for t in self._worker_threads: + t.join() + + self._worker_threads.clear() + + def _setup_workers(self): + if self._worker_threads or not self._workers: + # There already are workers, or workers is None or 0. Do nothing. + return + + for i in range(self._workers): + thread = Thread( + target=UpdateState._worker_loop, + name='UpdateWorker{}'.format(i), + daemon=True, + args=(self, i) + ) + self._worker_threads.append(thread) + thread.start() + + def _worker_loop(self, wid): + while True: + try: + update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT) + # TODO Maybe people can add different handlers per update type + if update: + for handler in self.handlers: + handler(update) + except StopIteration: + break + except Exception as e: + # We don't want to crash a worker thread due to any reason + self._logger.debug( + '[ERROR] Unhandled exception on worker {}'.format(wid), e + ) def process(self, update): """Processes an update object. This method is normally called by the library itself. """ - if not self._polling and not self.handlers: - return + if self._workers is None: + return # No processing needs to be done if nobody's working with self._updates_lock: if isinstance(update, tl.updates.State): self._state = update - elif not hasattr(update, 'pts') or update.pts > self._state.pts: - self._state.pts = getattr(update, 'pts', self._state.pts) + return # Nothing else to be done - if self._polling: - self._updates.append(update) - self._updates_available.set() + pts = getattr(update, 'pts', self._state.pts) + if hasattr(update, 'pts') and pts <= self._state.pts: + return # We already handled this update - for handler in self.handlers: - handler(update) + self._state.pts = pts + self._updates.append(update) + self._updates_available.set() diff --git a/telethon/utils.py b/telethon/utils.py index aa82c472..273dc962 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -10,8 +10,17 @@ from .tl.types import ( ChatPhoto, InputPeerChannel, InputPeerChat, InputPeerUser, InputPeerEmpty, MessageMediaDocument, MessageMediaPhoto, PeerChannel, InputChannel, UserEmpty, InputUser, InputUserEmpty, InputUserSelf, InputPeerSelf, - PeerChat, PeerUser, User, UserFull, UserProfilePhoto, Document -) + PeerChat, PeerUser, User, UserFull, UserProfilePhoto, Document, + MessageMediaContact, MessageMediaEmpty, MessageMediaGame, MessageMediaGeo, + MessageMediaUnsupported, MessageMediaVenue, InputMediaContact, + InputMediaDocument, InputMediaEmpty, InputMediaGame, + InputMediaGeoPoint, InputMediaPhoto, InputMediaVenue, InputDocument, + DocumentEmpty, InputDocumentEmpty, Message, GeoPoint, InputGeoPoint, + GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty, + InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty, + FileLocationUnavailable, InputMediaUploadedDocument, + InputMediaUploadedPhoto, + DocumentAttributeFilename) def get_display_name(entity): @@ -65,13 +74,10 @@ def _raise_cast_fail(entity, target): def get_input_peer(entity): """Gets the input peer for the given "entity" (user, chat or channel). A ValueError is raised if the given entity isn't a supported type.""" - if entity is None: - return None - if not isinstance(entity, TLObject): _raise_cast_fail(entity, 'InputPeer') - if type(entity).subclass_of_id == 0xc91c90b6: # crc32(b'InputPeer') + if type(entity).SUBCLASS_OF_ID == 0xc91c90b6: # crc32(b'InputPeer') return entity if isinstance(entity, User): @@ -109,13 +115,10 @@ def get_input_peer(entity): def get_input_channel(entity): """Similar to get_input_peer, but for InputChannel's alone""" - if entity is None: - return None - if not isinstance(entity, TLObject): _raise_cast_fail(entity, 'InputChannel') - if type(entity).subclass_of_id == 0x40f202fd: # crc32(b'InputChannel') + if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel') return entity if isinstance(entity, Channel) or isinstance(entity, ChannelForbidden): @@ -129,13 +132,10 @@ def get_input_channel(entity): def get_input_user(entity): """Similar to get_input_peer, but for InputUser's alone""" - if entity is None: - return None - if not isinstance(entity, TLObject): _raise_cast_fail(entity, 'InputUser') - if type(entity).subclass_of_id == 0xe669bf46: # crc32(b'InputUser') + if type(entity).SUBCLASS_OF_ID == 0xe669bf46: # crc32(b'InputUser') return entity if isinstance(entity, User): @@ -156,27 +156,169 @@ def get_input_user(entity): _raise_cast_fail(entity, 'InputUser') +def get_input_document(document): + """Similar to get_input_peer, but for documents""" + if not isinstance(document, TLObject): + _raise_cast_fail(document, 'InputDocument') + + if type(document).SUBCLASS_OF_ID == 0xf33fdb68: # crc32(b'InputDocument') + return document + + if isinstance(document, Document): + return InputDocument(id=document.id, access_hash=document.access_hash) + + if isinstance(document, DocumentEmpty): + return InputDocumentEmpty() + + if isinstance(document, MessageMediaDocument): + return get_input_document(document.document) + + if isinstance(document, Message): + return get_input_document(document.media) + + _raise_cast_fail(document, 'InputDocument') + + +def get_input_photo(photo): + """Similar to get_input_peer, but for documents""" + if not isinstance(photo, TLObject): + _raise_cast_fail(photo, 'InputPhoto') + + if type(photo).SUBCLASS_OF_ID == 0x846363e0: # crc32(b'InputPhoto') + return photo + + if isinstance(photo, Photo): + return InputPhoto(id=photo.id, access_hash=photo.access_hash) + + if isinstance(photo, PhotoEmpty): + return InputPhotoEmpty() + + _raise_cast_fail(photo, 'InputPhoto') + + +def get_input_geo(geo): + """Similar to get_input_peer, but for geo points""" + if not isinstance(geo, TLObject): + _raise_cast_fail(geo, 'InputGeoPoint') + + if type(geo).SUBCLASS_OF_ID == 0x430d225: # crc32(b'InputGeoPoint') + return geo + + if isinstance(geo, GeoPoint): + return InputGeoPoint(lat=geo.lat, long=geo.long) + + if isinstance(geo, GeoPointEmpty): + return InputGeoPointEmpty() + + if isinstance(geo, MessageMediaGeo): + return get_input_geo(geo.geo) + + if isinstance(geo, Message): + return get_input_geo(geo.media) + + _raise_cast_fail(geo, 'InputGeoPoint') + + +def get_input_media(media, user_caption=None, is_photo=False): + """Similar to get_input_peer, but for media. + + If the media is a file location and is_photo is known to be True, + it will be treated as an InputMediaUploadedPhoto. + """ + if not isinstance(media, TLObject): + _raise_cast_fail(media, 'InputMedia') + + if type(media).SUBCLASS_OF_ID == 0xfaf846f4: # crc32(b'InputMedia') + return media + + if isinstance(media, MessageMediaPhoto): + return InputMediaPhoto( + id=get_input_photo(media.photo), + caption=media.caption if user_caption is None else user_caption, + ttl_seconds=media.ttl_seconds + ) + + if isinstance(media, MessageMediaDocument): + return InputMediaDocument( + id=get_input_document(media.document), + caption=media.caption if user_caption is None else user_caption, + ttl_seconds=media.ttl_seconds + ) + + if isinstance(media, FileLocation): + if is_photo: + return InputMediaUploadedPhoto( + file=media, + caption=user_caption or '' + ) + else: + return InputMediaUploadedDocument( + file=media, + mime_type='application/octet-stream', # unknown, assume bytes + attributes=[DocumentAttributeFilename('unnamed')], + caption=user_caption or '' + ) + + if isinstance(media, MessageMediaGame): + return InputMediaGame(id=media.game.id) + + if isinstance(media, ChatPhoto) or isinstance(media, UserProfilePhoto): + if isinstance(media.photo_big, FileLocationUnavailable): + return get_input_media(media.photo_small, is_photo=True) + else: + return get_input_media(media.photo_big, is_photo=True) + + if isinstance(media, MessageMediaContact): + return InputMediaContact( + phone_number=media.phone_number, + first_name=media.first_name, + last_name=media.last_name + ) + + if isinstance(media, MessageMediaGeo): + return InputMediaGeoPoint(geo_point=get_input_geo(media.geo)) + + if isinstance(media, MessageMediaVenue): + return InputMediaVenue( + geo_point=get_input_geo(media.geo), + title=media.title, + address=media.address, + provider=media.provider, + venue_id=media.venue_id + ) + + if any(isinstance(media, t) for t in ( + MessageMediaEmpty, MessageMediaUnsupported, + FileLocationUnavailable, ChatPhotoEmpty, + UserProfilePhotoEmpty)): + return InputMediaEmpty() + + if isinstance(media, Message): + return get_input_media(media.media) + + _raise_cast_fail(media, 'InputMedia') + + def find_user_or_chat(peer, users, chats): """Finds the corresponding user or chat given a peer. Returns None if it was not found""" - try: - if isinstance(peer, PeerUser): - return next(u for u in users if u.id == peer.user_id) - - elif isinstance(peer, PeerChat): - return next(c for c in chats if c.id == peer.chat_id) - + if isinstance(peer, PeerUser): + peer, where = peer.user_id, users + else: + where = chats + if isinstance(peer, PeerChat): + peer = peer.chat_id elif isinstance(peer, PeerChannel): - return next(c for c in chats if c.id == peer.channel_id) - - except StopIteration: return + peer = peer.channel_id if isinstance(peer, int): - try: return next(u for u in users if u.id == peer) - except StopIteration: pass - - try: return next(c for c in chats if c.id == peer) - except StopIteration: pass + if isinstance(where, dict): + return where.get(peer) + else: + try: + return next(x for x in where if x.id == peer) + except StopIteration: + pass def get_appropriated_part_size(file_size): diff --git a/telethon_generator/parser/tl_object.py b/telethon_generator/parser/tl_object.py index c8ccae83..416bc587 100644 --- a/telethon_generator/parser/tl_object.py +++ b/telethon_generator/parser/tl_object.py @@ -96,6 +96,17 @@ class TLObject: result=match.group(3), is_function=is_function) + def class_name(self): + """Gets the class name following the Python style guidelines""" + + # Courtesy of http://stackoverflow.com/a/31531797/4759433 + result = re.sub(r'_([a-z])', lambda m: m.group(1).upper(), self.name) + result = result[:1].upper() + result[1:].replace('_', '') + # If it's a function, let it end with "Request" to identify them + if self.is_function: + result += 'Request' + return result + def sorted_args(self): """Returns the arguments properly sorted and ready to plug-in into a Python's method header (i.e., flags and those which @@ -197,8 +208,8 @@ class TLArg: else: self.flag_indicator = False self.is_generic = arg_type.startswith('!') - self.type = arg_type.lstrip( - '!') # Strip the exclamation mark always to have only the name + # Strip the exclamation mark always to have only the name + self.type = arg_type.lstrip('!') # The type may be a flag (flags.IDX?REAL_TYPE) # Note that 'flags' is NOT the flags name; this is determined by a previous argument @@ -233,6 +244,24 @@ class TLArg: self.generic_definition = generic_definition + def type_hint(self): + result = { + 'int': 'int', + 'long': 'int', + 'int128': 'int', + 'int256': 'int', + 'string': 'str', + 'date': 'datetime.datetime | None', # None date = 0 timestamp + 'bytes': 'bytes', + 'true': 'bool', + }.get(self.type, 'TLObject') + if self.is_vector: + result = 'list[{}]'.format(result) + if self.is_flag and self.type != 'date': + result += ' | None' + + return result + def __str__(self): # Find the real type representation by updating it as required real_type = self.type diff --git a/telethon_generator/scheme.tl b/telethon_generator/scheme.tl index c0b41dc9..5e949239 100644 --- a/telethon_generator/scheme.tl +++ b/telethon_generator/scheme.tl @@ -32,16 +32,16 @@ /// Authorization key creation /////////////////////////////// -resPQ#05162463 nonce:int128 server_nonce:int128 pq:string server_public_key_fingerprints:Vector = ResPQ; +resPQ#05162463 nonce:int128 server_nonce:int128 pq:bytes server_public_key_fingerprints:Vector = ResPQ; -p_q_inner_data#83c95aec pq:string p:string q:string nonce:int128 server_nonce:int128 new_nonce:int256 = P_Q_inner_data; +p_q_inner_data#83c95aec pq:bytes p:bytes q:bytes nonce:int128 server_nonce:int128 new_nonce:int256 = P_Q_inner_data; server_DH_params_fail#79cb045d nonce:int128 server_nonce:int128 new_nonce_hash:int128 = Server_DH_Params; -server_DH_params_ok#d0e8075c nonce:int128 server_nonce:int128 encrypted_answer:string = Server_DH_Params; +server_DH_params_ok#d0e8075c nonce:int128 server_nonce:int128 encrypted_answer:bytes = Server_DH_Params; -server_DH_inner_data#b5890dba nonce:int128 server_nonce:int128 g:int dh_prime:string g_a:string server_time:int = Server_DH_inner_data; +server_DH_inner_data#b5890dba nonce:int128 server_nonce:int128 g:int dh_prime:bytes g_a:bytes server_time:int = Server_DH_inner_data; -client_DH_inner_data#6643b654 nonce:int128 server_nonce:int128 retry_id:long g_b:string = Client_DH_Inner_Data; +client_DH_inner_data#6643b654 nonce:int128 server_nonce:int128 retry_id:long g_b:bytes = Client_DH_Inner_Data; dh_gen_ok#3bcbf734 nonce:int128 server_nonce:int128 new_nonce_hash1:int128 = Set_client_DH_params_answer; dh_gen_retry#46dc1fb9 nonce:int128 server_nonce:int128 new_nonce_hash2:int128 = Set_client_DH_params_answer; @@ -55,9 +55,9 @@ destroy_auth_key_fail#ea109b13 = DestroyAuthKeyRes; req_pq#60469778 nonce:int128 = ResPQ; -req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:string q:string public_key_fingerprint:long encrypted_data:string = Server_DH_Params; +req_DH_params#d712e4be nonce:int128 server_nonce:int128 p:bytes q:bytes public_key_fingerprint:long encrypted_data:bytes = Server_DH_Params; -set_client_DH_params#f5045f1f nonce:int128 server_nonce:int128 encrypted_data:string = Set_client_DH_params_answer; +set_client_DH_params#f5045f1f nonce:int128 server_nonce:int128 encrypted_data:bytes = Set_client_DH_params_answer; destroy_auth_key#d1435160 = DestroyAuthKeyRes; diff --git a/telethon_generator/tl_generator.py b/telethon_generator/tl_generator.py index e0d207ac..e76dffaa 100644 --- a/telethon_generator/tl_generator.py +++ b/telethon_generator/tl_generator.py @@ -1,6 +1,7 @@ import os import re import shutil +import struct from zlib import crc32 from collections import defaultdict @@ -107,8 +108,7 @@ class TLGenerator: if tlobject.namespace: builder.write('.' + tlobject.namespace) - builder.writeln('.{},'.format( - TLGenerator.get_class_name(tlobject))) + builder.writeln('.{},'.format(tlobject.class_name())) builder.current_indent -= 1 builder.writeln('}') @@ -137,13 +137,29 @@ class TLGenerator: x for x in namespace_tlobjects.keys() if x ))) + # Import 'get_input_*' utils + # TODO Support them on types too + if 'functions' in out_dir: + builder.writeln( + 'from {}.utils import get_input_peer, ' + 'get_input_channel, get_input_user, ' + 'get_input_media'.format('.' * depth) + ) + + # Import 'os' for those needing access to 'os.urandom()' + # Currently only 'random_id' needs 'os' to be imported, + # for all those TLObjects with arg.can_be_inferred. + builder.writeln('import os') + + # Import struct for the .to_bytes(self) serialization + builder.writeln('import struct') + # Generate the class for every TLObject for t in sorted(tlobjects, key=lambda x: x.name): TLGenerator._write_source_code( t, builder, depth, type_constructors ) - while builder.current_indent != 0: - builder.end_block() + builder.current_indent = 0 @staticmethod def _write_source_code(tlobject, builder, depth, type_constructors): @@ -154,35 +170,15 @@ class TLGenerator: the Type: [Constructors] must be given for proper importing and documentation strings. """ - if tlobject.is_function: - util_imports = set() - for a in tlobject.args: - # We can automatically convert some "full" types to - # "input only" (like User -> InputPeerUser, etc.) - if a.type == 'InputPeer': - util_imports.add('get_input_peer') - elif a.type == 'InputChannel': - util_imports.add('get_input_channel') - elif a.type == 'InputUser': - util_imports.add('get_input_user') - - if util_imports: - builder.writeln('from {}.utils import {}'.format( - '.' * depth, ', '.join(util_imports))) - - if any(a for a in tlobject.args if a.can_be_inferred): - # Currently only 'random_id' needs 'os' to be imported - builder.writeln('import os') - builder.writeln() builder.writeln() - builder.writeln('class {}(TLObject):'.format( - TLGenerator.get_class_name(tlobject))) + builder.writeln('class {}(TLObject):'.format(tlobject.class_name())) # Class-level variable to store its Telegram's constructor ID - builder.writeln('constructor_id = {}'.format(hex(tlobject.id))) - builder.writeln('subclass_of_id = {}'.format( - hex(crc32(tlobject.result.encode('ascii'))))) + builder.writeln('CONSTRUCTOR_ID = {}'.format(hex(tlobject.id))) + builder.writeln('SUBCLASS_OF_ID = {}'.format( + hex(crc32(tlobject.result.encode('ascii')))) + ) builder.writeln() # Flag arguments must go last @@ -221,17 +217,10 @@ class TLGenerator: builder.writeln('"""') for arg in args: if not arg.flag_indicator: - builder.write( - ':param {}: Telegram type: "{}".' - .format(arg.name, arg.type) - ) - if arg.is_vector: - builder.write(' Must be a list.'.format(arg.name)) - - if arg.is_generic: - builder.write(' Must be another TLObject request.') - - builder.writeln() + builder.writeln(':param {} {}:'.format( + arg.type_hint(), arg.name + )) + builder.current_indent -= 1 # It will auto-indent (':') # We also want to know what type this request returns # or to which type this constructor belongs to @@ -246,12 +235,11 @@ class TLGenerator: builder.writeln('This type has no constructors.') elif len(constructors) == 1: builder.writeln('Instance of {}.'.format( - TLGenerator.get_class_name(constructors[0]) + constructors[0].class_name() )) else: builder.writeln('Instance of either {}.'.format( - ', '.join(TLGenerator.get_class_name(c) - for c in constructors) + ', '.join(c.class_name() for c in constructors) )) builder.writeln('"""') @@ -274,56 +262,77 @@ class TLGenerator: builder.end_block() # Write the to_dict(self) method + builder.writeln('def to_dict(self, recursive=True):') if args: - builder.writeln('def to_dict(self):') builder.writeln('return {') - builder.current_indent += 1 - - base_types = ('string', 'bytes', 'int', 'long', 'int128', - 'int256', 'double', 'Bool', 'true', 'date') - - for arg in args: - builder.write("'{}': ".format(arg.name)) - if arg.type in base_types: - if arg.is_vector: - builder.write( - '[] if self.{0} is None else self.{0}[:]' - .format(arg.name) - ) - else: - builder.write('self.{}'.format(arg.name)) - else: - if arg.is_vector: - builder.write( - '[] if self.{0} is None else [None ' - 'if x is None else x.to_dict() for x in self.{0}]' - .format(arg.name) - ) - else: - builder.write( - 'None if self.{0} is None else self.{0}.to_dict()' - .format(arg.name) - ) - builder.writeln(',') - - builder.current_indent -= 1 - builder.writeln("}") else: - builder.writeln('@staticmethod') - builder.writeln('def to_dict():') - builder.writeln('return {}') + builder.write('return {') + builder.current_indent += 1 + + base_types = ('string', 'bytes', 'int', 'long', 'int128', + 'int256', 'double', 'Bool', 'true', 'date') + + for arg in args: + builder.write("'{}': ".format(arg.name)) + if arg.type in base_types: + if arg.is_vector: + builder.write('[] if self.{0} is None else self.{0}[:]' + .format(arg.name)) + else: + builder.write('self.{}'.format(arg.name)) + else: + if arg.is_vector: + builder.write( + '([] if self.{0} is None else [None' + ' if x is None else x.to_dict() for x in self.{0}]' + ') if recursive else self.{0}'.format(arg.name) + ) + else: + builder.write( + '(None if self.{0} is None else self.{0}.to_dict())' + ' if recursive else self.{0}'.format(arg.name) + ) + builder.writeln(',') + + builder.current_indent -= 1 + builder.writeln("}") builder.end_block() - # Write the on_send(self, writer) function - builder.writeln('def on_send(self, writer):') - builder.writeln( - 'writer.write_int({}.constructor_id, signed=False)' - .format(TLGenerator.get_class_name(tlobject))) + # Write the .to_bytes() function + builder.writeln('def to_bytes(self):') + + # Some objects require more than one flag parameter to be set + # at the same time. In this case, add an assertion. + repeated_args = defaultdict(list) + for arg in tlobject.args: + if arg.is_flag: + repeated_args[arg.flag_index].append(arg) + + for ra in repeated_args.values(): + if len(ra) > 1: + cnd1 = ('self.{} is None'.format(a.name) for a in ra) + cnd2 = ('self.{} is not None'.format(a.name) for a in ra) + builder.writeln( + "assert ({}) or ({}), '{} parameters must all " + "be None or neither be None'".format( + ' and '.join(cnd1), ' and '.join(cnd2), + ', '.join(a.name for a in ra) + ) + ) + + builder.writeln("return b''.join((") + builder.current_indent += 1 + + # First constructor code, we already know its bytes + builder.writeln('{},'.format(repr(struct.pack(' """ - if arg.generic_definition: return # Do nothing, this only specifies a later type @@ -470,73 +462,91 @@ class TLGenerator: if arg.is_flag: if arg.type == 'true': return # Exit, since True type is never written + elif arg.is_vector: + # Vector flags are special since they consist of 3 values, + # so we need an extra join here. Note that empty vector flags + # should NOT be sent either! + builder.write("b'' if not {} else b''.join((".format(name)) else: - builder.writeln('if {}:'.format(name)) + builder.write("b'' if not {} else (".format(name)) if arg.is_vector: if arg.use_vector_id: - builder.writeln('writer.write_int(0x1cb5c415, signed=False)') + # vector code, unsigned 0x1cb5c415 as little endian + builder.write(r"b'\x15\xc4\xb5\x1c',") + + builder.write("struct.pack('3.5 feature, so add another join. + builder.write("b''.join(") - builder.writeln('writer.write_int(len({}))'.format(name)) - builder.writeln('for _x in {}:'.format(name)) # Temporary disable .is_vector, not to enter this if again - arg.is_vector = False - TLGenerator.write_onsend_code(builder, arg, args, name='_x') + # Also disable .is_flag since it's not needed per element + old_flag = arg.is_flag + arg.is_vector = arg.is_flag = False + TLGenerator.write_to_bytes(builder, arg, args, name='x') arg.is_vector = True + arg.is_flag = old_flag + + builder.write(' for x in {})'.format(name)) elif arg.flag_indicator: # Calculate the flags with those items which are not None - builder.writeln('flags = 0') - for flag in args: - if flag.is_flag: - builder.writeln('flags |= (1 << {}) if {} else 0'.format( - flag.flag_index, 'self.{}'.format(flag.name))) - - builder.writeln('writer.write_int(flags)') - builder.writeln() + builder.write("struct.pack('