Create a new in-memory cache for entities (#1141)

This commit is contained in:
Lonami Exo 2019-03-26 11:27:21 +01:00
parent facf3ae582
commit 4d35e8c80f
5 changed files with 62 additions and 16 deletions

View File

@ -2,7 +2,6 @@ import abc
import asyncio import asyncio
import logging import logging
import platform import platform
import sys
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
@ -13,6 +12,7 @@ from ..network import MTProtoSender, ConnectionTcpFull, TcpMTProxy
from ..sessions import Session, SQLiteSession, MemorySession from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import TLObject, functions, types from ..tl import TLObject, functions, types
from ..tl.alltlobjects import LAYER from ..tl.alltlobjects import LAYER
from ..entitycache import EntityCache
DEFAULT_DC_ID = 4 DEFAULT_DC_ID = 4
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -229,14 +229,17 @@ class TelegramBaseClient(abc.ABC):
self.flood_sleep_threshold = flood_sleep_threshold self.flood_sleep_threshold = flood_sleep_threshold
# TODO Figure out how to use AsyncClassWrapper(session) # TODO Use AsyncClassWrapper(session)
# The problem is that ChatGetter and SenderGetter rely # ChatGetter and SenderGetter can use the in-memory _entity_cache
# on synchronous calls to session.get_entity precisely # to avoid network access and the need for await in session files.
# to avoid network access and the need for await.
# #
# With asynchronous sessions, it would need await, # The session files only want the entities to persist
# and defeats the purpose of properties. # them to disk, and to save additional useful information.
# TODO Make use of _entity_cache
# TODO Session should probably return all cached
# info of entities, not just the input versions
self.session = session self.session = session
self._entity_cache = EntityCache()
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash

View File

@ -179,6 +179,8 @@ class UpdateMethods(UserMethods):
async def _handle_update(self, update): async def _handle_update(self, update):
self.session.process_entities(update) self.session.process_entities(update)
self._entity_cache.add(update)
if isinstance(update, (types.Updates, types.UpdatesCombined)): if isinstance(update, (types.Updates, types.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)} itertools.chain(update.users, update.chats)}

View File

@ -49,6 +49,7 @@ class UserMethods(TelegramBaseClient):
results.append(None) results.append(None)
continue continue
self.session.process_entities(result) self.session.process_entities(result)
self._entity_cache.add(result)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -59,6 +60,7 @@ class UserMethods(TelegramBaseClient):
else: else:
result = await future result = await future
self.session.process_entities(result) self.session.process_entities(result)
self._entity_cache.add(result)
return result return result
except (errors.ServerError, errors.RpcCallFailError, except (errors.ServerError, errors.RpcCallFailError,
errors.RpcMcgetFailError) as e: errors.RpcMcgetFailError) as e:

45
telethon/entitycache.py Normal file
View File

@ -0,0 +1,45 @@
import itertools
from . import utils
from .tl import types
class EntityCache:
    """
    In-memory input entity cache, defaultdict-like behaviour.

    Entries are stored in the instance ``__dict__``, keyed by the value
    returned from `utils.get_peer_id` and holding the value returned
    from `utils.get_input_peer`. Looking up something that is not in
    the cache yields `None` instead of raising.
    """
    def add(self, entities):
        """
        Adds the given entities to the cache, if they weren't saved before.

        `entities` may be a list-like of entities or any single object.
        For a single object, its ``user``, ``chats`` and ``users``
        attributes (whichever exist) are the ones inspected. An ID that
        is already cached is left untouched, and anything that raises
        `TypeError` during conversion is skipped.
        """
        if not utils.is_list_like(entities):
            # Invariant: all "chats" and "users" are always iterables
            entities = itertools.chain(
                [getattr(entities, 'user', None)],
                getattr(entities, 'chats', []),
                getattr(entities, 'users', [])
            )

        for entity in entities:
            try:
                pid = utils.get_peer_id(entity)
                if pid not in self.__dict__:
                    # Note: `get_input_peer` already checks for `access_hash`
                    self.__dict__[pid] = utils.get_input_peer(entity)
            except TypeError:
                # Best-effort: whatever can't be converted is ignored.
                # Presumably this includes the `None` placeholder used
                # above when there is no ``user`` attribute.
                pass

    def __getitem__(self, item):
        """
        Gets the corresponding :tl:`InputPeer` for the given ID or peer,
        or returns `None` on error/not found.

        Negative integers and non-integers are resolved through
        `utils.get_peer_id` and looked up directly. A non-negative
        integer is tried, in order, as a user, chat and channel ID.
        """
        if not isinstance(item, int) or item < 0:
            try:
                return self.__dict__.get(utils.get_peer_id(item))
            except TypeError:
                return None

        for cls in (types.PeerUser, types.PeerChat, types.PeerChannel):
            # The cache is keyed by what `get_peer_id` returns (see `add`),
            # so the candidate peer must go through the same conversion.
            # Using the peer object itself as the key would never match.
            result = self.__dict__.get(utils.get_peer_id(cls(item)))
            if result is not None:
                return result

        return None

View File

@ -99,18 +99,12 @@ class MemorySession(Session):
p = utils.get_input_peer(e, allow_self=False) p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p) marked_id = utils.get_peer_id(p)
except TypeError: except TypeError:
# Note: `get_input_peer` already checks for
# non-zero `access_hash`. See issues #354 and #392.
return return
if isinstance(p, (InputPeerUser, InputPeerChannel)): if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash: p_hash = p.access_hash
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
return
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat): elif isinstance(p, InputPeerChat):
p_hash = 0 p_hash = 0
else: else: