Make raw API types immutable

This commit is contained in:
Lonami Exo 2022-01-26 12:14:17 +01:00
parent d426099bf5
commit 070af28e85
5 changed files with 90 additions and 75 deletions

View File

@ -775,3 +775,5 @@ raise_last_call_error is now the default rather than ValueError
self-produced updates like getmessage now also trigger a handler self-produced updates like getmessage now also trigger a handler
input_peer removed from get_me; input peers should remain mostly an impl detail input_peer removed from get_me; input peers should remain mostly an impl detail
raw api types and fns are now immutable. this can enable optimizations in the future.

View File

@ -8,6 +8,7 @@ import time
import typing import typing
import ipaddress import ipaddress
import dataclasses import dataclasses
import functools
from .. import version, __name__ as __base_name__, _tl from .. import version, __name__ as __base_name__, _tl
from .._crypto import rsa from .._crypto import rsa
@ -182,7 +183,8 @@ def init(
default_device_model = system.machine default_device_model = system.machine
default_system_version = re.sub(r'-.+','',system.release) default_system_version = re.sub(r'-.+','',system.release)
self._init_request = _tl.fn.InitConnection( self._init_request = functools.partial(
_tl.fn.InitConnection,
api_id=self._api_id, api_id=self._api_id,
device_model=device_model or default_device_model or 'Unknown', device_model=device_model or default_device_model or 'Unknown',
system_version=system_version or default_system_version or '1.0', system_version=system_version or default_system_version or '1.0',
@ -190,8 +192,6 @@ def init(
lang_code=lang_code, lang_code=lang_code,
system_lang_code=system_lang_code, system_lang_code=system_lang_code,
lang_pack='', # "langPacks are for official apps only" lang_pack='', # "langPacks are for official apps only"
query=None,
proxy=None
) )
self._sender = MTProtoSender( self._sender = MTProtoSender(
@ -272,10 +272,8 @@ async def connect(self: 'TelegramClient') -> None:
# Need to send invokeWithLayer for things to work out. # Need to send invokeWithLayer for things to work out.
# Make the most out of this opportunity by also refreshing our state. # Make the most out of this opportunity by also refreshing our state.
# During the v1 to v2 migration, this also correctly sets the IPv* columns. # During the v1 to v2 migration, this also correctly sets the IPv* columns.
self._init_request.query = _tl.fn.help.GetConfig()
config = await self._sender.send(_tl.fn.InvokeWithLayer( config = await self._sender.send(_tl.fn.InvokeWithLayer(
_tl.LAYER, self._init_request _tl.LAYER, self._init_request(query=_tl.fn.help.GetConfig())
)) ))
for dc in config.dc_options: for dc in config.dc_options:
@ -318,7 +316,6 @@ async def disconnect(self: 'TelegramClient'):
def set_proxy(self: 'TelegramClient', proxy: typing.Union[tuple, dict]): def set_proxy(self: 'TelegramClient', proxy: typing.Union[tuple, dict]):
init_proxy = None init_proxy = None
self._init_request.proxy = init_proxy
self._proxy = proxy self._proxy = proxy
# While `await client.connect()` passes new proxy on each new call, # While `await client.connect()` passes new proxy on each new call,
@ -408,8 +405,9 @@ async def _create_exported_sender(self: 'TelegramClient', dc_id):
)) ))
self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc) self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc)
auth = await self(_tl.fn.auth.ExportAuthorization(dc_id)) auth = await self(_tl.fn.auth.ExportAuthorization(dc_id))
self._init_request.query = _tl.fn.auth.ImportAuthorization(id=auth.id, bytes=auth.bytes) req = _tl.fn.InvokeWithLayer(_tl.LAYER, self._init_request(
req = _tl.fn.InvokeWithLayer(_tl.LAYER, self._init_request) query=_tl.fn.auth.ImportAuthorization(id=auth.id, bytes=auth.bytes)
))
await sender.send(req) await sender.send(req)
return sender return sender

View File

@ -35,10 +35,11 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
if flood_sleep_threshold is None: if flood_sleep_threshold is None:
flood_sleep_threshold = self.flood_sleep_threshold flood_sleep_threshold = self.flood_sleep_threshold
requests = (request if utils.is_list_like(request) else (request,)) requests = (request if utils.is_list_like(request) else (request,))
new_requests = []
for r in requests: for r in requests:
if not isinstance(r, _tl.TLRequest): if not isinstance(r, _tl.TLRequest):
raise _NOT_A_REQUEST() raise _NOT_A_REQUEST()
await r.resolve(self, utils) r = await r.resolve(self, utils)
# Avoid making the request if it's already in a flood wait # Avoid making the request if it's already in a flood wait
if r.CONSTRUCTOR_ID in self._flood_waited_requests: if r.CONSTRUCTOR_ID in self._flood_waited_requests:
@ -59,6 +60,9 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
if self._no_updates: if self._no_updates:
r = _tl.fn.InvokeWithoutUpdates(r) r = _tl.fn.InvokeWithoutUpdates(r)
new_requests.append(r)
request = new_requests if utils.is_list_like(request) else new_requests[0]
request_index = 0 request_index = 0
last_error = None last_error = None
self._last_request = time.time() self._last_request = time.time()

View File

@ -155,4 +155,4 @@ class TLRequest(TLObject):
return reader.tgread_object() return reader.tgread_object()
async def resolve(self, client, utils): async def resolve(self, client, utils):
pass return self

View File

@ -85,6 +85,9 @@ def _write_modules(
# Import struct for the .__bytes__(self) serialization # Import struct for the .__bytes__(self) serialization
builder.writeln('import struct') builder.writeln('import struct')
# Import dataclasses in order to freeze the instances
builder.writeln('import dataclasses')
# Import datetime for type hinting # Import datetime for type hinting
builder.writeln('from datetime import datetime') builder.writeln('from datetime import datetime')
@ -187,37 +190,9 @@ def _write_source_code(tlobject, kind, builder, type_constructors):
def _write_class_init(tlobject, kind, type_constructors, builder): def _write_class_init(tlobject, kind, type_constructors, builder):
builder.writeln() builder.writeln()
builder.writeln() builder.writeln()
builder.writeln('@dataclasses.dataclass(init=False, frozen=True)')
builder.writeln('class {}({}):', tlobject.class_name, kind) builder.writeln('class {}({}):', tlobject.class_name, kind)
# Define slots to help reduce the size of the objects a little bit.
# It's also good for knowing what fields an object has.
builder.write('__slots__ = (')
sep = ''
for arg in tlobject.real_args:
builder.write('{}{!r},', sep, arg.name)
sep = ' '
builder.writeln(')')
# Class-level variable to store its Telegram's constructor ID
builder.writeln('CONSTRUCTOR_ID = {:#x}', tlobject.id)
builder.writeln('SUBCLASS_OF_ID = {:#x}',
crc32(tlobject.result.encode('ascii')))
builder.writeln()
# Convert the args to string parameters, flags having =None
args = ['{}: {}{}'.format(
a.name, a.type_hint(), '=None' if a.is_flag or a.can_be_inferred else '')
for a in tlobject.real_args
]
# Write the __init__ function if it has any argument
if not tlobject.real_args:
return
if any(a.name in dir(builtins) for a in tlobject.real_args):
builder.writeln('# noinspection PyShadowingBuiltins')
builder.writeln("def __init__({}):", ', '.join(['self'] + args))
builder.writeln('"""') builder.writeln('"""')
if tlobject.is_function: if tlobject.is_function:
builder.write(':returns {}: ', tlobject.result) builder.write(':returns {}: ', tlobject.result)
@ -236,47 +211,83 @@ def _write_class_init(tlobject, kind, type_constructors, builder):
builder.writeln('"""') builder.writeln('"""')
# Set the arguments # Define slots to help reduce the size of the objects a little bit.
# It's also good for knowing what fields an object has.
builder.write('__slots__ = (')
sep = ''
for arg in tlobject.real_args: for arg in tlobject.real_args:
if not arg.can_be_inferred: builder.write('{}{!r},', sep, arg.name)
builder.writeln('self.{0} = {0}', arg.name) sep = ' '
builder.writeln(')')
# Currently the only argument that can be # Class-level variable to store its Telegram's constructor ID
# inferred are those called 'random_id' builder.writeln('CONSTRUCTOR_ID = {:#x}', tlobject.id)
elif arg.name == 'random_id': builder.writeln('SUBCLASS_OF_ID = {:#x}',
# Endianness doesn't really matter, and 'big' is shorter crc32(tlobject.result.encode('ascii')))
code = "int.from_bytes(os.urandom({}), 'big', signed=True)" \ builder.writeln()
.format(8 if arg.type == 'long' else 4)
if arg.is_vector: # Because we're using __slots__ and frozen instances, we cannot have flags = None directly.
# Currently for the case of "messages.forwardMessages" # See https://stackoverflow.com/q/50180735 (Python 3.10 does offer a solution).
# Ensure we can infer the length from id:Vector<> # Write the __init__ function if it has any argument.
if not next(a for a in tlobject.real_args if tlobject.real_args:
if a.name == 'id').is_vector: # Convert the args to string parameters
raise ValueError( for a in tlobject.real_args:
'Cannot infer list of random ids for ', tlobject builder.writeln('{}: {}', a.name, a.type_hint())
)
code = '[{} for _ in range(len(id))]'.format(code)
builder.writeln( # Convert the args to string parameters, flags having =None
"self.random_id = random_id if random_id " args = ['{}: {}{}'.format(
"is not None else {}", code a.name, a.type_hint(), '=None' if a.is_flag or a.can_be_inferred else '')
) for a in tlobject.real_args
else: ]
raise ValueError('Cannot infer a value for ', arg)
builder.end_block() if any(a.name in dir(builtins) for a in tlobject.real_args):
builder.writeln('# noinspection PyShadowingBuiltins')
builder.writeln("def __init__({}):", ', '.join(['self'] + args))
# Set the arguments
for arg in tlobject.real_args:
builder.writeln("object.__setattr__(self, '{0}', {0})", arg.name)
builder.end_block()
def _write_resolve(tlobject, builder): def _write_resolve(tlobject, builder):
if tlobject.is_function and any( if tlobject.is_function and any(
(arg.type in AUTO_CASTS (arg.can_be_inferred
or ((arg.name, arg.type) in NAMED_AUTO_CASTS or arg.type in AUTO_CASTS
and tlobject.fullname not in NAMED_BLACKLIST)) or ((arg.name, arg.type) in NAMED_AUTO_CASTS and tlobject.fullname not in NAMED_BLACKLIST))
for arg in tlobject.real_args for arg in tlobject.real_args
): ):
builder.writeln('async def resolve(self, client, utils):') builder.writeln('async def resolve(self, client, utils):')
builder.writeln('r = {}') # hold replacements
for arg in tlobject.real_args: for arg in tlobject.real_args:
if arg.can_be_inferred:
builder.writeln('if self.{} is None:', arg.name)
# Currently the only argument that can be
# inferred are those called 'random_id'
if arg.name == 'random_id':
# Endianness doesn't really matter, and 'big' is shorter
code = "int.from_bytes(os.urandom({}), 'big', signed=True)" \
.format(8 if arg.type == 'long' else 4)
if arg.is_vector:
# Currently for the case of "messages.forwardMessages"
# Ensure we can infer the length from id:Vector<>
if not next(a for a in tlobject.real_args if a.name == 'id').is_vector:
raise ValueError('Cannot infer list of random ids for ', tlobject)
code = '[{} for _ in range(len(self.id))]'.format(code)
builder.writeln("r['{}'] = {}", arg.name, code)
else:
raise ValueError('Cannot infer a value for ', arg)
builder.end_block()
continue
ac = AUTO_CASTS.get(arg.type) ac = AUTO_CASTS.get(arg.type)
if not ac: if not ac:
ac = NAMED_AUTO_CASTS.get((arg.name, arg.type)) ac = NAMED_AUTO_CASTS.get((arg.name, arg.type))
@ -287,17 +298,17 @@ def _write_resolve(tlobject, builder):
builder.writeln('if self.{}:', arg.name) builder.writeln('if self.{}:', arg.name)
if arg.is_vector: if arg.is_vector:
builder.writeln('_tmp = []') builder.writeln("r['{}'] = []", arg.name)
builder.writeln('for _x in self.{0}:', arg.name) builder.writeln('for x in self.{0}:', arg.name)
builder.writeln('_tmp.append({})', ac.format('_x')) builder.writeln("r['{}'].append({})", arg.name, ac.format('x'))
builder.end_block() builder.end_block()
builder.writeln('self.{} = _tmp', arg.name)
else: else:
builder.writeln('self.{} = {}', arg.name, builder.writeln("r['{}'] = {}", arg.name, ac.format('self.' + arg.name))
ac.format('self.' + arg.name))
if arg.is_flag: if arg.is_flag:
builder.end_block() builder.end_block()
builder.writeln('return dataclasses.replace(self, **r)')
builder.end_block() builder.end_block()