Remove redundant entity cache

Progress towards #3989.
May also help with #3235.
This commit is contained in:
Lonami Exo 2023-04-06 13:09:07 +02:00
parent 3e64ea35ff
commit f7e38ee6f0
14 changed files with 35 additions and 32 deletions

View File

@ -300,7 +300,7 @@ class TelegramBaseClient(abc.ABC):
self.flood_sleep_threshold = flood_sleep_threshold self.flood_sleep_threshold = flood_sleep_threshold
# TODO Use AsyncClassWrapper(session) # TODO Use AsyncClassWrapper(session)
# ChatGetter and SenderGetter can use the in-memory _entity_cache # ChatGetter and SenderGetter can use the in-memory _mb_entity_cache
# to avoid network access and the need for await in session files. # to avoid network access and the need for await in session files.
# #
# The session files only wants the entities to persist # The session files only wants the entities to persist
@ -308,7 +308,6 @@ class TelegramBaseClient(abc.ABC):
# TODO Session should probably return all cached # TODO Session should probably return all cached
# info of entities, not just the input versions # info of entities, not just the input versions
self.session = session self.session = session
self._entity_cache = EntityCache()
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
@ -433,7 +432,6 @@ class TelegramBaseClient(abc.ABC):
self._catch_up = catch_up self._catch_up = catch_up
self._updates_queue = asyncio.Queue() self._updates_queue = asyncio.Queue()
self._message_box = MessageBox(self._log['messagebox']) self._message_box = MessageBox(self._log['messagebox'])
# This entity cache is tailored for the messagebox and is not used for absolutely everything like _entity_cache
self._mb_entity_cache = MbEntityCache() # required for proper update handling (to know when to getDifference) self._mb_entity_cache = MbEntityCache() # required for proper update handling (to know when to getDifference)
self._sender = MTProtoSender( self._sender = MTProtoSender(

View File

@ -72,7 +72,6 @@ class UserMethods:
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) self.session.process_entities(result)
self._entity_cache.add(result)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -83,7 +82,6 @@ class UserMethods:
else: else:
result = await future result = await future
self.session.process_entities(result) self.session.process_entities(result)
self._entity_cache.add(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
errors.RpcMcgetFailError, errors.InterdcCallErrorError, errors.RpcMcgetFailError, errors.InterdcCallErrorError,
@ -417,8 +415,8 @@ class UserMethods:
try: try:
# 0x2d45687 == crc32(b'Peer') # 0x2d45687 == crc32(b'Peer')
if isinstance(peer, int) or peer.SUBCLASS_OF_ID == 0x2d45687: if isinstance(peer, int) or peer.SUBCLASS_OF_ID == 0x2d45687:
return self._entity_cache[peer] return self._mb_entity_cache.get(utils.get_peer_id(peer, add_mark=False))._as_input_peer()
except (AttributeError, KeyError): except AttributeError:
pass pass
# Then come known strings that take precedence # Then come known strings that take precedence

View File

@ -160,7 +160,7 @@ class Album(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(
self.sender_id, self._entities, client._entity_cache) self.sender_id, self._entities, client._mb_entity_cache)
for msg in self.messages: for msg in self.messages:
msg._finish_init(client, self._entities, None) msg._finish_init(client, self._entities, None)

View File

@ -151,7 +151,7 @@ class CallbackQuery(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(
self.sender_id, self._entities, client._entity_cache) self.sender_id, self._entities, client._mb_entity_cache)
@property @property
def id(self): def id(self):
@ -208,8 +208,9 @@ class CallbackQuery(EventBuilder):
if not getattr(self._input_sender, 'access_hash', True): if not getattr(self._input_sender, 'access_hash', True):
# getattr with True to handle the InputPeerSelf() case # getattr with True to handle the InputPeerSelf() case
try: try:
self._input_sender = self._client._entity_cache[self._sender_id] self._input_sender = self._client._mb_entity_cache.get(
except KeyError: utils.resolve_id(self._sender_id)[0])._as_input_peer()
except AttributeError:
m = await self.get_message() m = await self.get_message()
if m: if m:
self._sender = m._sender self._sender = m._sender

View File

@ -425,9 +425,10 @@ class ChatAction(EventBuilder):
# If missing, try from the entity cache # If missing, try from the entity cache
try: try:
self._input_users.append(self._client._entity_cache[user_id]) self._input_users.append(self._client._mb_entity_cache.get(
utils.resolve_id(user_id)[0])._as_input_peer())
continue continue
except KeyError: except AttributeError:
pass pass
return self._input_users or [] return self._input_users or []

View File

@ -154,7 +154,7 @@ class EventCommon(ChatGetter, abc.ABC):
self._client = client self._client = client
if self._chat_peer: if self._chat_peer:
self._chat, self._input_chat = utils._get_entity_pair( self._chat, self._input_chat = utils._get_entity_pair(
self.chat_id, self._entities, client._entity_cache) self.chat_id, self._entities, client._mb_entity_cache)
else: else:
self._chat = self._input_chat = None self._chat = self._input_chat = None

View File

@ -99,7 +99,7 @@ class InlineQuery(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(
self.sender_id, self._entities, client._entity_cache) self.sender_id, self._entities, client._mb_entity_cache)
@property @property
def id(self): def id(self):

View File

@ -95,7 +95,7 @@ class UserUpdate(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(
self.sender_id, self._entities, client._entity_cache) self.sender_id, self._entities, client._mb_entity_cache)
@property @property
def user(self): def user(self):

View File

@ -66,8 +66,9 @@ class ChatGetter(abc.ABC):
""" """
if self._input_chat is None and self._chat_peer and self._client: if self._input_chat is None and self._chat_peer and self._client:
try: try:
self._input_chat = self._client._entity_cache[self._chat_peer] self._input_chat = self._client._mb_entity_cache.get(
except KeyError: utils.get_peer_id(self._chat_peer, add_mark=False))._as_input_peer()
except AttributeError:
pass pass
return self._input_chat return self._input_chat

View File

@ -5,7 +5,7 @@ from ..functions.messages import SaveDraftRequest
from ..types import DraftMessage from ..types import DraftMessage
from ...errors import RPCError from ...errors import RPCError
from ...extensions import markdown from ...extensions import markdown
from ...utils import get_input_peer, get_peer from ...utils import get_input_peer, get_peer, get_peer_id
class Draft: class Draft:
@ -53,8 +53,9 @@ class Draft:
""" """
if not self._input_entity: if not self._input_entity:
try: try:
self._input_entity = self._client._entity_cache[self._peer] self._input_entity = self._client._mb_entity_cache.get(
except KeyError: get_peer_id(self._peer, add_mark=False))._as_input_peer()
except AttributeError:
pass pass
return self._input_entity return self._input_entity

View File

@ -36,12 +36,12 @@ class Forward(ChatGetter, SenderGetter):
if ty == helpers._EntityType.USER: if ty == helpers._EntityType.USER:
sender_id = utils.get_peer_id(original.from_id) sender_id = utils.get_peer_id(original.from_id)
sender, input_sender = utils._get_entity_pair( sender, input_sender = utils._get_entity_pair(
sender_id, entities, client._entity_cache) sender_id, entities, client._mb_entity_cache)
elif ty in (helpers._EntityType.CHAT, helpers._EntityType.CHANNEL): elif ty in (helpers._EntityType.CHAT, helpers._EntityType.CHANNEL):
peer = original.from_id peer = original.from_id
chat, input_chat = utils._get_entity_pair( chat, input_chat = utils._get_entity_pair(
utils.get_peer_id(peer), entities, client._entity_cache) utils.get_peer_id(peer), entities, client._mb_entity_cache)
# This call resets the client # This call resets the client
ChatGetter.__init__(self, peer, chat=chat, input_chat=input_chat) ChatGetter.__init__(self, peer, chat=chat, input_chat=input_chat)

View File

@ -285,7 +285,7 @@ class Message(ChatGetter, SenderGetter, TLObject):
if self.peer_id == types.PeerUser(client._self_id) and not self.fwd_from: if self.peer_id == types.PeerUser(client._self_id) and not self.fwd_from:
self.out = True self.out = True
cache = client._entity_cache cache = client._mb_entity_cache
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(
self.sender_id, entities, cache) self.sender_id, entities, cache)
@ -1138,8 +1138,9 @@ class Message(ChatGetter, SenderGetter, TLObject):
return bot return bot
else: else:
try: try:
return self._client._entity_cache[self.via_bot_id] return self._client._mb_entity_cache.get(
except KeyError: utils.resolve_id(self.via_bot_id)[0])._as_input_peer()
except AttributeError:
raise ValueError('No input sender') from None raise ValueError('No input sender') from None
def _document_by_attribute(self, kind, condition=None): def _document_by_attribute(self, kind, condition=None):

View File

@ -1,5 +1,7 @@
import abc import abc
from ... import utils
class SenderGetter(abc.ABC): class SenderGetter(abc.ABC):
""" """
@ -69,9 +71,9 @@ class SenderGetter(abc.ABC):
""" """
if self._input_sender is None and self._sender_id and self._client: if self._input_sender is None and self._sender_id and self._client:
try: try:
self._input_sender = \ self._input_sender = self._client._mb_entity_cache.get(
self._client._entity_cache[self._sender_id] utils.resolve_id(self._sender_id)[0])._as_input_peer()
except KeyError: except AttributeError:
pass pass
return self._input_sender return self._input_sender

View File

@ -585,9 +585,9 @@ def _get_entity_pair(entity_id, entities, cache,
""" """
entity = entities.get(entity_id) entity = entities.get(entity_id)
try: try:
input_entity = cache[entity_id] input_entity = cache.get(resolve_id(entity_id)[0])._as_input_peer()
except KeyError: except AttributeError:
# KeyError is unlikely, so another TypeError won't hurt # AttributeError is unlikely, so another TypeError won't hurt
try: try:
input_entity = get_input_peer(entity) input_entity = get_input_peer(entity)
except TypeError: except TypeError: