diff --git a/requirements.txt b/requirements.txt index 2b650ec4..45e8c141 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pyaes rsa +typing diff --git a/setup.py b/setup.py index 143ca0cb..05ca9197 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ Extra supported commands are: # To use a consistent encoding from codecs import open -from sys import argv +from sys import argv, version_info import os import re @@ -153,7 +153,8 @@ def main(): 'telethon_generator/parser/tl_object.py', 'telethon_generator/parser/tl_parser.py', ]), - install_requires=['pyaes', 'rsa'], + install_requires=['pyaes', 'rsa', + 'typing' if version_info < (3, 5) else ""], extras_require={ 'cryptg': ['cryptg'], 'sqlalchemy': ['sqlalchemy'] diff --git a/telethon_generator/parser/tl_object.py b/telethon_generator/parser/tl_object.py index 034cb3c3..0e0045d7 100644 --- a/telethon_generator/parser/tl_object.py +++ b/telethon_generator/parser/tl_object.py @@ -254,7 +254,7 @@ class TLArg: self.generic_definition = generic_definition - def type_hint(self): + def doc_type_hint(self): result = { 'int': 'int', 'long': 'int', @@ -272,6 +272,27 @@ class TLArg: return result + def python_type_hint(self): + type = self.type + if '.' in type: + type = type.split('.')[1] + result = { + 'int': 'int', + 'long': 'int', + 'int128': 'int', + 'int256': 'int', + 'string': 'str', + 'date': 'Optional[datetime]', # None date = 0 timestamp + 'bytes': 'bytes', + 'true': 'bool', + }.get(type, "Type{}".format(type)) + if self.is_vector: + result = 'List[{}]'.format(result) + if self.is_flag and type != 'date': + result = 'Optional[{}]'.format(result) + + return result + def __str__(self): # Find the real type representation by updating it as required real_type = self.type diff --git a/telethon_generator/tl_generator.py b/telethon_generator/tl_generator.py index ff12acfe..7c1f6237 100644 --- a/telethon_generator/tl_generator.py +++ b/telethon_generator/tl_generator.py @@ -138,6 +138,7 @@ class TLGenerator: builder.writeln( 'from {}.tl.tlobject import TLObject'.format('.' * depth) ) + builder.writeln('from typing import Optional, List, Union, TYPE_CHECKING') # Add the relative imports to the namespaces, # unless we already are in a namespace. @@ -154,13 +155,81 @@ class TLGenerator: # Import struct for the .__bytes__(self) serialization builder.writeln('import struct') + tlobjects.sort(key=lambda x: x.name) + + type_names = set() + type_defs = [] + + # Find all the types in this file and generate type definitions + # based on the types. The type definitions are written to the + # file at the end. + for t in tlobjects: + if not t.is_function: + type_name = t.result + if '.' in type_name: + type_name = type_name[type_name.rindex('.'):] + if type_name in type_names: + continue + type_names.add(type_name) + constructors = type_constructors[type_name] + if not constructors: + pass + elif len(constructors) == 1: + type_defs.append('Type{} = {}'.format( + type_name, constructors[0].class_name())) + else: + type_defs.append('Type{} = Union[{}]'.format( + type_name, ','.join(c.class_name() + for c in constructors))) + + imports = {} + primitives = ('int', 'long', 'int128', 'int256', 'string', + 'date', 'bytes', 'true') + # Find all the types in other files that are used in this file + # and generate the information required to import those types. + for t in tlobjects: + for arg in t.args: + name = arg.type + if not name or name in primitives: + continue + + import_space = '{}.tl.types'.format('.' * depth) + if '.' in name: + namespace = name.split('.')[0] + name = name.split('.')[1] + import_space += '.{}'.format(namespace) + + if name not in type_names: + type_names.add(name) + if name == 'date': + imports['datetime'] = ['datetime'] + continue + elif not import_space in imports: + imports[import_space] = set() + imports[import_space].add('Type{}'.format(name)) + + # Add imports required for type checking. + builder.writeln('if TYPE_CHECKING:') + for namespace, names in imports.items(): + builder.writeln('from {} import {}'.format( + namespace, ', '.join(names))) + else: + builder.writeln('pass') + builder.end_block() + # Generate the class for every TLObject - for t in sorted(tlobjects, key=lambda x: x.name): + for t in tlobjects: TLGenerator._write_source_code( t, builder, depth, type_constructors ) builder.current_indent = 0 + # Write the type definitions generated earlier. + builder.writeln('') + for line in type_defs: + builder.writeln(line) + + @staticmethod def _write_source_code(tlobject, builder, depth, type_constructors): """Writes the source code corresponding to the given TLObject @@ -218,7 +287,7 @@ class TLGenerator: for arg in args: if not arg.flag_indicator: builder.writeln(':param {} {}:'.format( - arg.type_hint(), arg.name + arg.doc_type_hint(), arg.name )) builder.current_indent -= 1 # It will auto-indent (':') @@ -258,7 +327,8 @@ class TLGenerator: for arg in args: if not arg.can_be_inferred: - builder.writeln('self.{0} = {0}'.format(arg.name)) + builder.writeln('self.{0} = {0} # type: {1}'.format( + arg.name, arg.python_type_hint())) continue # Currently the only argument that can be