Attempt at reducing CPU usage after c902428

This attempt removes the forced `await` call, which could
be causing that extra usage. Some more boilerplate is needed.
This commit is contained in:
Lonami Exo 2019-04-23 20:15:27 +02:00
parent 56595e4a9c
commit 1b6b4a57d9
3 changed files with 89 additions and 66 deletions

View File

@ -213,13 +213,17 @@ class UpdateMethods(UserMethods):
self._state_cache.update(update) self._state_cache.update(update)
def _process_update(self, update, entities=None): def _process_update(self, update, entities=None):
update._channel_id = self._state_cache.get_channel_id(update)
update._pts_date = self._state_cache[update._channel_id]
update._entities = entities or {} update._entities = entities or {}
# This part is somewhat hot so we don't bother patching
# update with channel ID/its state. Instead we just pass
# arguments which is faster.
channel_id = self._state_cache.get_channel_id(update)
args = (update, channel_id, self._state_cache[channel_id])
if self._updates_queue is None: if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update)) self._loop.create_task(self._dispatch_update(*args))
else: else:
self._updates_queue.put_nowait(update) self._updates_queue.put_nowait(args)
if not self._dispatching_updates_queue.is_set(): if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set() self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates()) self._loop.create_task(self._dispatch_queue_updates())
@ -273,25 +277,37 @@ class UpdateMethods(UserMethods):
async def _dispatch_queue_updates(self): async def _dispatch_queue_updates(self):
while not self._updates_queue.empty(): while not self._updates_queue.empty():
await self._dispatch_update(self._updates_queue.get_nowait()) await self._dispatch_update(*self._updates_queue.get_nowait())
self._dispatching_updates_queue.clear() self._dispatching_updates_queue.clear()
async def _dispatch_update(self, update): async def _dispatch_update(self, update, channel_id, pts_date):
built = EventBuilderDict(self, update) built = EventBuilderDict(self, update)
if self._conversations: if self._conversations:
for conv in self._conversations.values(): for conv in self._conversations.values():
if await built.get(events.NewMessage): ev = built[events.NewMessage]
conv._on_new_message(built[events.NewMessage]) if ev:
if await built.get(events.MessageEdited): if not ev._load_entities():
conv._on_edit(built[events.MessageEdited]) await ev._get_difference(channel_id, pts_date)
if await built.get(events.MessageRead): conv._on_new_message(ev)
conv._on_read(built[events.MessageRead])
ev = built[events.MessageEdited]
if ev:
if not ev._load_entities():
await ev._get_difference(channel_id, pts_date)
conv._on_edit(ev)
ev = built[events.MessageRead]
if ev:
if not ev._load_entities():
await ev._get_difference(channel_id, pts_date)
conv._on_read(ev)
if conv._custom: if conv._custom:
await conv._check_custom(built) await conv._check_custom(built, channel_id, pts_date)
for builder, callback in self._event_builders: for builder, callback in self._event_builders:
event = await built.get(type(builder)) event = built[type(builder)]
if not event: if not event:
continue continue
@ -302,6 +318,14 @@ class UpdateMethods(UserMethods):
continue continue
try: try:
# Although needing to do this constantly is annoying and
# error-prone, this part is somewhat hot, and always doing
# `await` for `check_entities_and_get_difference` causes
# unnecessary work. So we need to call a function that
# doesn't cause a task switch.
if not event._load_entities():
await event._get_difference(channel_id, pts_date)
await callback(event) await callback(event)
except errors.AlreadyInConversationError: except errors.AlreadyInConversationError:
name = getattr(callback, '__name__', repr(callback)) name = getattr(callback, '__name__', repr(callback))
@ -367,9 +391,6 @@ class EventBuilderDict:
self.update = update self.update = update
def __getitem__(self, builder): def __getitem__(self, builder):
return self.__dict__[builder]
async def get(self, builder):
try: try:
return self.__dict__[builder] return self.__dict__[builder]
except KeyError: except KeyError:
@ -377,54 +398,7 @@ class EventBuilderDict:
if isinstance(event, EventCommon): if isinstance(event, EventCommon):
event.original_update = self.update event.original_update = self.update
event._set_client(self.client) event._set_client(self.client)
if not event._load_entities():
await self.get_difference()
if not event._load_entities():
self.client._log[__name__].info(
'Could not find all entities for update.pts = %s',
getattr(self.update, 'pts', None)
)
elif event: elif event:
# Actually a :tl:`Update`, not much processing to do
event._client = self.client event._client = self.client
return event return event
async def get_difference(self):
"""
Calls :tl:`updates.getDifference`, which fills the entities cache
(always done by `__call__`) and lets us know about the full entities.
"""
# Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full, including its entities.
self.client._log[__name__].debug('Getting difference for entities')
if self.update._channel_id:
pts = self.update._pts_date
try:
where = await self.client.get_input_entity(self.update._channel_id)
except ValueError:
return
result = await self.client(functions.updates.GetChannelDifferenceRequest(
channel=where,
filter=types.ChannelMessagesFilterEmpty(),
pts=pts,
limit=100,
force=True
))
else:
pts, date = self.update._pts_date
result = await self.client(functions.updates.GetDifferenceRequest(
pts=pts - 1,
date=date,
qts=0
))
if isinstance(result, (types.updates.Difference,
types.updates.DifferenceSlice,
types.updates.ChannelDifference,
types.updates.ChannelDifferenceTooLong)):
self.update._entities.update({
utils.get_peer_id(x): x for x in
itertools.chain(result.users, result.chats)
})

View File

@ -1,9 +1,10 @@
import abc import abc
import asyncio import asyncio
import itertools
import warnings import warnings
from .. import utils from .. import utils
from ..tl import TLObject, types from ..tl import TLObject, types, functions
from ..tl.custom.chatgetter import ChatGetter from ..tl.custom.chatgetter import ChatGetter
@ -174,6 +175,51 @@ class EventCommon(ChatGetter, abc.ABC):
self._chat, self._input_chat = self._get_entity_pair(self.chat_id) self._chat, self._input_chat = self._get_entity_pair(self.chat_id)
return self._input_chat is not None return self._input_chat is not None
async def _get_difference(self, channel_id, pts_date):
"""
Get the difference for this `channel_id` if any, then load entities.
Calls :tl:`updates.getDifference`, which fills the entities cache
(always done by `__call__`) and lets us know about the full entities.
"""
# Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full, including its entities.
self.client._log[__name__].debug('Getting difference for entities')
if channel_id:
try:
where = await self.client.get_input_entity(channel_id)
except ValueError:
return
result = await self.client(functions.updates.GetChannelDifferenceRequest(
channel=where,
filter=types.ChannelMessagesFilterEmpty(),
pts=pts_date, # just pts
limit=100,
force=True
))
else:
result = await self.client(functions.updates.GetDifferenceRequest(
pts=pts_date[0],
date=pts_date[1],
qts=0
))
if isinstance(result, (types.updates.Difference,
types.updates.DifferenceSlice,
types.updates.ChannelDifference,
types.updates.ChannelDifferenceTooLong)):
self.original_update._entities.update({
utils.get_peer_id(x): x for x in
itertools.chain(result.users, result.chats)
})
if not self._load_entities():
self.client._log[__name__].info(
'Could not find all entities for update.pts = %s',
getattr(self.original_update, 'pts', None)
)
@property @property
def client(self): def client(self):
""" """

View File

@ -294,10 +294,13 @@ class Conversation(ChatGetter):
self._custom[counter] = (event, future) self._custom[counter] = (event, future)
return await result() return await result()
async def _check_custom(self, built): async def _check_custom(self, built, channel_id, pts_date):
for i, (ev, fut) in self._custom.items(): for i, (ev, fut) in self._custom.items():
ev_type = type(ev) ev_type = type(ev)
if built[ev_type] and ev.filter(built[ev_type]): if built[ev_type] and ev.filter(built[ev_type]):
if not ev._load_entities():
await ev._get_difference(channel_id, pts_date)
fut.set_result(built[ev_type]) fut.set_result(built[ev_type])
def _on_new_message(self, response): def _on_new_message(self, response):