Use __slots__ in all generated classes

This commit is contained in:
Lonami Exo 2021-09-26 17:52:16 +02:00
parent 8bd4835eb2
commit 86c47a2771
5 changed files with 38 additions and 26 deletions

View File

@ -286,8 +286,8 @@ results into a list:
// TODO does the download really need to be special? get download is kind of weird though // TODO does the download really need to be special? get download is kind of weird though
Raw API methods have been renamed and are now considered private Raw API has been renamed and is now considered private
---------------------------------------------------------------- ------------------------------------------------------
The subpackage holding the raw API methods has been renamed from ``tl`` to ``_tl`` in order to The subpackage holding the raw API methods has been renamed from ``tl`` to ``_tl`` in order to
signal that these are prone to change across minor version bumps (the ``y`` in version ``x.y.z``). signal that these are prone to change across minor version bumps (the ``y`` in version ``x.y.z``).
@ -324,7 +324,14 @@ This serves multiple goals:
identify which parts are making use of it. identify which parts are making use of it.
* The name is shorter, but remains recognizable. * The name is shorter, but remains recognizable.
Because *a lot* of these objects are created, they now define ``__slots__``. This means you can
no longer monkey-patch them to add new attributes at runtime. You have to create a subclass if you
want to define new attributes.
This also means that the updates from ``events.Raw`` **no longer have** ``update._entities``.
// TODO this definitely generated files mapping from the original name to this new one... // TODO this definitely generated files mapping from the original name to this new one...
// TODO what's the alternative to update._entities? and update._client??
Many subpackages and modules are now private Many subpackages and modules are now private

View File

@ -143,22 +143,20 @@ def _handle_update(self: 'TelegramClient', update):
entities = {utils.get_peer_id(x): x for x in entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)} itertools.chain(update.users, update.chats)}
for u in update.updates: for u in update.updates:
_process_update(self, u, update.updates, entities=entities) _process_update(self, u, entities, update.updates)
elif isinstance(update, _tl.UpdateShort): elif isinstance(update, _tl.UpdateShort):
_process_update(self, update.update, None) _process_update(self, update.update, {}, None)
else: else:
_process_update(self, update, None) _process_update(self, update, {}, None)
self._state_cache.update(update) self._state_cache.update(update)
def _process_update(self: 'TelegramClient', update, others, entities=None): def _process_update(self: 'TelegramClient', update, entities, others):
update._entities = entities or {}
# This part is somewhat hot so we don't bother patching # This part is somewhat hot so we don't bother patching
# update with channel ID/its state. Instead we just pass # update with channel ID/its state. Instead we just pass
# arguments which is faster. # arguments which is faster.
channel_id = self._state_cache.get_channel_id(update) channel_id = self._state_cache.get_channel_id(update)
args = (update, others, channel_id, self._state_cache[channel_id]) args = (update, entities, others, channel_id, self._state_cache[channel_id])
if self._dispatching_updates_queue is None: if self._dispatching_updates_queue is None:
task = self.loop.create_task(_dispatch_update(self, *args)) task = self.loop.create_task(_dispatch_update(self, *args))
self._updates_queue.add(task) self._updates_queue.add(task)
@ -231,10 +229,11 @@ async def _dispatch_queue_updates(self: 'TelegramClient'):
self._dispatching_updates_queue.clear() self._dispatching_updates_queue.clear()
async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, pts_date): async def _dispatch_update(self: 'TelegramClient', update, entities, others, channel_id, pts_date):
entities = self._entity_cache.add(list((update._entities or {}).values()))
if entities: if entities:
await self.session.insert_entities(entities) rows = self._entity_cache.add(list(entities.values()))
if rows:
await self.session.insert_entities(rows)
if not self._entity_cache.ensure_cached(update): if not self._entity_cache.ensure_cached(update):
# We could add a lock to not fetch the same pts twice if we are # We could add a lock to not fetch the same pts twice if we are
@ -244,7 +243,7 @@ async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, p
# If the update doesn't have pts, fetching won't do anything. # If the update doesn't have pts, fetching won't do anything.
# For example, UpdateUserStatus or UpdateChatUserTyping. # For example, UpdateUserStatus or UpdateChatUserTyping.
try: try:
await _get_difference(self, update, channel_id, pts_date) await _get_difference(self, update, entities, channel_id, pts_date)
except OSError: except OSError:
pass # We were disconnected, that's okay pass # We were disconnected, that's okay
except RpcError: except RpcError:
@ -258,7 +257,7 @@ async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, p
# ValueError("Request was unsuccessful N time(s)") for whatever reasons. # ValueError("Request was unsuccessful N time(s)") for whatever reasons.
pass pass
built = EventBuilderDict(self, update, others) built = EventBuilderDict(self, update, entities, others)
for builder, callback in self._event_builders: for builder, callback in self._event_builders:
event = built[type(builder)] event = built[type(builder)]
@ -324,7 +323,7 @@ async def _dispatch_event(self: 'TelegramClient', event):
name = getattr(callback, '__name__', repr(callback)) name = getattr(callback, '__name__', repr(callback))
self._log[__name__].exception('Unhandled exception on %s', name) self._log[__name__].exception('Unhandled exception on %s', name)
async def _get_difference(self: 'TelegramClient', update, channel_id, pts_date): async def _get_difference(self: 'TelegramClient', update, entities, channel_id, pts_date):
""" """
Get the difference for this `channel_id` if any, then load entities. Get the difference for this `channel_id` if any, then load entities.
@ -380,7 +379,7 @@ async def _get_difference(self: 'TelegramClient', update, channel_id, pts_date):
_tl.updates.DifferenceSlice, _tl.updates.DifferenceSlice,
_tl.updates.ChannelDifference, _tl.updates.ChannelDifference,
_tl.updates.ChannelDifferenceTooLong)): _tl.updates.ChannelDifferenceTooLong)):
update._entities.update({ entities.update({
utils.get_peer_id(x): x for x in utils.get_peer_id(x): x for x in
itertools.chain(result.users, result.chats) itertools.chain(result.users, result.chats)
}) })
@ -433,9 +432,10 @@ class EventBuilderDict:
""" """
Helper "dictionary" to return events from types and cache them. Helper "dictionary" to return events from types and cache them.
""" """
def __init__(self, client: 'TelegramClient', update, others): def __init__(self, client: 'TelegramClient', update, entities, others):
self.client = client self.client = client
self.update = update self.update = update
self.entities = entities
self.others = others self.others = others
def __getitem__(self, builder): def __getitem__(self, builder):
@ -447,9 +447,7 @@ class EventBuilderDict:
if isinstance(event, EventCommon): if isinstance(event, EventCommon):
event.original_update = self.update event.original_update = self.update
event._entities = self.update._entities event._entities = self.entities or {}
event._set_client(self.client) event._set_client(self.client)
elif event:
event._client = self.client
return event return event

View File

@ -13,7 +13,7 @@ def _datetime_to_timestamp(dt):
# If no timezone is specified, it is assumed to be in utc zone # If no timezone is specified, it is assumed to be in utc zone
if dt.tzinfo is None: if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc) dt = dt.replace(tzinfo=timezone.utc)
# We use .total_seconds() method instead of simply dt.timestamp(), # We use .total_seconds() method instead of simply dt.timestamp(),
# because on Windows the latter raises OSError on datetimes ~< datetime(1970,1,1) # because on Windows the latter raises OSError on datetimes ~< datetime(1970,1,1)
secs = int((dt - _EPOCH).total_seconds()) secs = int((dt - _EPOCH).total_seconds())
# Make sure it's a valid signed 32 bit integer, as used by Telegram. # Make sure it's a valid signed 32 bit integer, as used by Telegram.
@ -32,6 +32,7 @@ def _json_default(value):
class TLObject: class TLObject:
__slots__ = ()
CONSTRUCTOR_ID = None CONSTRUCTOR_ID = None
SUBCLASS_OF_ID = None SUBCLASS_OF_ID = None

View File

@ -11,13 +11,10 @@ from ... import _tl
def _fwd(field, doc): def _fwd(field, doc):
def fget(self): def fget(self):
try: return getattr(self._message, field, None)
return self._message.__dict__[field]
except KeyError:
return None
def fset(self, value): def fset(self, value):
self._message.__dict__[field] = value setattr(self._message, field, value)
return property(fget, fset, None, doc) return property(fget, fset, None, doc)

View File

@ -189,6 +189,15 @@ def _write_class_init(tlobject, kind, type_constructors, builder):
builder.writeln() builder.writeln()
builder.writeln('class {}({}):', tlobject.class_name, kind) builder.writeln('class {}({}):', tlobject.class_name, kind)
# Define slots to help reduce the size of the objects a little bit.
# It's also good for knowing what fields an object has.
builder.write('__slots__ = (')
sep = ''
for arg in tlobject.real_args:
builder.write('{}{!r},', sep, arg.name)
sep = ' '
builder.writeln(')')
# Class-level variable to store its Telegram's constructor ID # Class-level variable to store its Telegram's constructor ID
builder.writeln('CONSTRUCTOR_ID = {:#x}', tlobject.id) builder.writeln('CONSTRUCTOR_ID = {:#x}', tlobject.id)
builder.writeln('SUBCLASS_OF_ID = {:#x}', builder.writeln('SUBCLASS_OF_ID = {:#x}',