mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-29 21:03:45 +03:00
Remove StateCache and EntityCache
This commit is contained in:
parent
8fc08a0c96
commit
7142734fb4
|
@ -375,8 +375,6 @@ async def log_out(self: 'TelegramClient') -> bool:
|
||||||
except errors.RPCError:
|
except errors.RPCError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self._state_cache.reset()
|
|
||||||
|
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -239,10 +239,6 @@ async def connect(self: 'TelegramClient') -> None:
|
||||||
)
|
)
|
||||||
all_dcs[dc.id] = dc
|
all_dcs[dc.id] = dc
|
||||||
|
|
||||||
# Update state (for catching up after a disconnection)
|
|
||||||
# TODO Get state from channels too
|
|
||||||
self._state_cache = statecache.StateCache(self._session_state, self._log)
|
|
||||||
|
|
||||||
# Use known key, if any
|
# Use known key, if any
|
||||||
self._sender.auth_key.key = dc.auth
|
self._sender.auth_key.key = dc.auth
|
||||||
|
|
||||||
|
@ -351,10 +347,6 @@ async def _disconnect_coro(self: 'TelegramClient'):
|
||||||
await asyncio.wait(self._updates_queue)
|
await asyncio.wait(self._updates_queue)
|
||||||
self._updates_queue.clear()
|
self._updates_queue.clear()
|
||||||
|
|
||||||
pts, date = self._state_cache[None]
|
|
||||||
if pts and date:
|
|
||||||
if self._session_state:
|
|
||||||
await self._replace_session_state(pts=pts, date=date)
|
|
||||||
|
|
||||||
async def _disconnect(self: 'TelegramClient'):
|
async def _disconnect(self: 'TelegramClient'):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -79,10 +79,7 @@ def list_event_handlers(self: 'TelegramClient')\
|
||||||
return [(callback, event) for event, callback in self._event_builders]
|
return [(callback, event) for event, callback in self._event_builders]
|
||||||
|
|
||||||
async def catch_up(self: 'TelegramClient'):
|
async def catch_up(self: 'TelegramClient'):
|
||||||
pts, date = self._state_cache[None]
|
|
||||||
if not pts:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._catching_up = True
|
self._catching_up = True
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
@ -131,8 +128,6 @@ async def catch_up(self: 'TelegramClient'):
|
||||||
except (ConnectionError, asyncio.CancelledError):
|
except (ConnectionError, asyncio.CancelledError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# TODO Save new pts to session
|
|
||||||
self._state_cache._pts_date = (pts, date)
|
|
||||||
self._catching_up = False
|
self._catching_up = False
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,14 +145,12 @@ def _handle_update(self: 'TelegramClient', update):
|
||||||
else:
|
else:
|
||||||
_process_update(self, update, {}, None)
|
_process_update(self, update, {}, None)
|
||||||
|
|
||||||
self._state_cache.update(update)
|
|
||||||
|
|
||||||
def _process_update(self: 'TelegramClient', update, entities, others):
|
def _process_update(self: 'TelegramClient', update, entities, others):
|
||||||
# This part is somewhat hot so we don't bother patching
|
# This part is somewhat hot so we don't bother patching
|
||||||
# update with channel ID/its state. Instead we just pass
|
# update with channel ID/its state. Instead we just pass
|
||||||
# arguments which is faster.
|
# arguments which is faster.
|
||||||
channel_id = self._state_cache.get_channel_id(update)
|
args = (update, entities, others, channel_id, None)
|
||||||
args = (update, entities, others, channel_id, self._state_cache[channel_id])
|
|
||||||
if self._dispatching_updates_queue is None:
|
if self._dispatching_updates_queue is None:
|
||||||
task = asyncio.create_task(_dispatch_update(self, *args))
|
task = asyncio.create_task(_dispatch_update(self, *args))
|
||||||
self._updates_queue.add(task)
|
self._updates_queue.add(task)
|
||||||
|
@ -168,8 +161,6 @@ def _process_update(self: 'TelegramClient', update, entities, others):
|
||||||
self._dispatching_updates_queue.set()
|
self._dispatching_updates_queue.set()
|
||||||
asyncio.create_task(_dispatch_queue_updates(self))
|
asyncio.create_task(_dispatch_queue_updates(self))
|
||||||
|
|
||||||
self._state_cache.update(update)
|
|
||||||
|
|
||||||
async def _update_loop(self: 'TelegramClient'):
|
async def _update_loop(self: 'TelegramClient'):
|
||||||
# Pings' ID don't really need to be secure, just "random"
|
# Pings' ID don't really need to be secure, just "random"
|
||||||
rnd = lambda: random.randrange(-2**63, 2**63)
|
rnd = lambda: random.randrange(-2**63, 2**63)
|
||||||
|
@ -326,7 +317,6 @@ async def _get_difference(self: 'TelegramClient', update, entities, channel_id,
|
||||||
result = await self(_tl.fn.channels.GetFullChannel(
|
result = await self(_tl.fn.channels.GetFullChannel(
|
||||||
utils.get_input_channel(where)
|
utils.get_input_channel(where)
|
||||||
))
|
))
|
||||||
self._state_cache[channel_id] = result.full_chat.pts
|
|
||||||
return
|
return
|
||||||
|
|
||||||
result = await self(_tl.fn.updates.GetChannelDifference(
|
result = await self(_tl.fn.updates.GetChannelDifference(
|
||||||
|
@ -340,7 +330,6 @@ async def _get_difference(self: 'TelegramClient', update, entities, channel_id,
|
||||||
if not pts_date[0]:
|
if not pts_date[0]:
|
||||||
# First-time, can't get difference. Get pts instead.
|
# First-time, can't get difference. Get pts instead.
|
||||||
result = await self(_tl.fn.updates.GetState())
|
result = await self(_tl.fn.updates.GetState())
|
||||||
self._state_cache[None] = result.pts, result.date
|
|
||||||
return
|
return
|
||||||
|
|
||||||
result = await self(_tl.fn.updates.GetDifference(
|
result = await self(_tl.fn.updates.GetDifference(
|
||||||
|
|
|
@ -1,179 +0,0 @@
|
||||||
import inspect
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
from .._misc import utils
|
|
||||||
from .. import _tl
|
|
||||||
from .._sessions.types import EntityType, Entity
|
|
||||||
|
|
||||||
# Which updates have the following fields?
|
|
||||||
# Lookup table keyed by ``(parameter name, parameter annotation)``.
# Every key starts with its own empty list, which is later filled with
# the constructor IDs of the update types declaring such a parameter.
_has_field = {
    field: []
    for field in (
        ('user_id', int),
        ('chat_id', int),
        ('channel_id', int),
        ('peer', 'TypePeer'),
        ('peer', 'TypeDialogPeer'),
        ('message', 'TypeMessage'),
    )
}
|
|
||||||
|
|
||||||
# Note: We don't bother checking for some rare fields:
|
|
||||||
# * `UpdateChatParticipantAdd.inviter_id` integer.
|
|
||||||
# * `UpdateNotifySettings.peer` dialog peer.
|
|
||||||
# * `UpdatePinnedDialogs.order` list of dialog peers.
|
|
||||||
# * `UpdateReadMessagesContents.messages` list of messages.
|
|
||||||
# * `UpdateChatParticipants.participants` list of participants.
|
|
||||||
#
|
|
||||||
# There are also some uninteresting `update.message` of type string.
|
|
||||||
|
|
||||||
|
|
||||||
def _fill():
    """
    Populate ``_has_field`` with the constructor IDs of the matching types.

    Every public attribute of ``_tl`` whose ``SUBCLASS_OF_ID`` equals
    ``0x9f89304e`` is inspected: for each parameter of its ``__init__``,
    the ``(name, annotation)`` pair is looked up in ``_has_field`` and,
    when present, the type's ``CONSTRUCTOR_ID`` is appended to that list.

    Raises ``RuntimeError`` if any list in ``_has_field`` is still empty
    after the scan.
    """
    # NOTE(review): presumably the subclass ID shared by all update
    # types -- confirm against the generated ``_tl`` module.
    wanted_subclass_id = 0x9f89304e

    for attr_name in dir(_tl):
        candidate = getattr(_tl, attr_name)
        if getattr(candidate, 'SUBCLASS_OF_ID', None) != wanted_subclass_id:
            continue

        constructor_id = candidate.CONSTRUCTOR_ID
        signature = inspect.signature(candidate.__init__)
        for parameter in signature.parameters.values():
            bucket = _has_field.get((parameter.name, parameter.annotation))
            if bucket is not None:
                bucket.append(constructor_id)

    # Future-proof check: if the documentation format ever changes
    # then we won't be able to pick the update types we are interested
    # in, so we must make sure we have at least an update for each field
    # which likely means we are doing it right.
    if not all(_has_field.values()):
        raise RuntimeError('FIXME: Did the init signature or updates change?')
|
|
||||||
|
|
||||||
|
|
||||||
# We use a function to avoid cluttering the globals (with name/update/cid/doc)
|
|
||||||
_fill()
|
|
||||||
|
|
||||||
|
|
||||||
class EntityCache:
    """
    In-memory input entity cache, defaultdict-like behaviour.

    The cached :tl:`InputPeer` objects are stored directly in the
    instance ``__dict__``, keyed by the entity ID.
    """
    def add(self, entities, _mappings={
        _tl.User.CONSTRUCTOR_ID: lambda e: (EntityType.BOT if e.bot else EntityType.USER, e.id, e.access_hash),
        _tl.UserFull.CONSTRUCTOR_ID: lambda e: (EntityType.BOT if e.user.bot else EntityType.USER, e.user.id, e.user.access_hash),
        _tl.Chat.CONSTRUCTOR_ID: lambda e: (EntityType.GROUP, e.id, 0),
        _tl.ChatFull.CONSTRUCTOR_ID: lambda e: (EntityType.GROUP, e.id, 0),
        _tl.ChatEmpty.CONSTRUCTOR_ID: lambda e: (EntityType.GROUP, e.id, 0),
        _tl.ChatForbidden.CONSTRUCTOR_ID: lambda e: (EntityType.GROUP, e.id, 0),
        _tl.Channel.CONSTRUCTOR_ID: lambda e: (
            EntityType.MEGAGROUP if e.megagroup else (EntityType.GIGAGROUP if e.gigagroup else EntityType.CHANNEL),
            e.id,
            e.access_hash,
        ),
        _tl.ChannelForbidden.CONSTRUCTOR_ID: lambda e: (EntityType.MEGAGROUP if e.megagroup else EntityType.CHANNEL, e.id, e.access_hash),
    }):
        """
        Adds the given entities to the cache, if they weren't saved before.

        ``entities`` may be a list-like of objects or a single object. A
        single object is considered together with whatever it carries in
        its ``chats``, ``users``, ``user`` and ``chat`` attributes.

        ``_mappings`` is a private lookup table (constructor ID to a
        function returning ``(type, id, access_hash)``) built once at
        definition time; callers are not expected to pass it.

        Returns a list of Entity that can be saved in the session.
        """
        if not utils.is_list_like(entities):
            # Invariant: all "chats" and "users" are always iterables,
            # and "user" and "chat" never are (so we wrap them inside a list).
            #
            # Itself may be already the entity we want to cache.
            entities = itertools.chain(
                [entities],
                getattr(entities, 'chats', []),
                getattr(entities, 'users', []),
                (hasattr(entities, 'user') and [entities.user]) or [],
                # This used to append `entities.user` again, so the chat was
                # never cached (and objects with `chat` but no `user` raised).
                (hasattr(entities, 'chat') and [entities.chat]) or [],
            )

        rows = []
        for e in entities:
            try:
                mapper = _mappings[e.CONSTRUCTOR_ID]
            except (AttributeError, KeyError):
                # Not a TL object, or not a type we know how to cache.
                continue

            ty, entity_id, access_hash = mapper(e)

            # Need to check for non-zero access hash unless it's a group (#354 and #392).
            # Also check it's not `min` (`access_hash` usage is limited since layer 102).
            #
            # NOTE(review): this compared against `Entity.GROUP` before; every
            # mapper above yields `EntityType` members, so `EntityType.GROUP`
            # is used for consistency -- confirm `Entity` has no such member.
            if not getattr(e, 'min', False) and (access_hash or ty == EntityType.GROUP):
                rows.append(Entity(ty, entity_id, access_hash))
                if entity_id not in self.__dict__:
                    if ty in (EntityType.USER, EntityType.BOT):
                        self.__dict__[entity_id] = _tl.InputPeerUser(entity_id, access_hash)
                    elif ty in (EntityType.GROUP,):
                        self.__dict__[entity_id] = _tl.InputPeerChat(entity_id)
                    elif ty in (EntityType.CHANNEL, EntityType.MEGAGROUP, EntityType.GIGAGROUP):
                        self.__dict__[entity_id] = _tl.InputPeerChannel(entity_id, access_hash)

        return rows

    def __getitem__(self, item):
        """
        Gets the corresponding :tl:`InputPeer` for the given ID or peer,
        or raises ``KeyError`` on any error (i.e. cannot be found).
        """
        if not isinstance(item, int) or item < 0:
            # Peer-like objects and negative integers are resolved through
            # `utils.get_peer_id`; a missing key raises `KeyError` on its own.
            try:
                return self.__dict__[utils.get_peer_id(item)]
            except TypeError:
                raise KeyError('Invalid key will not have entity') from None

        # A non-negative integer carries no type information, so try
        # interpreting it as each kind of peer in turn.
        for cls in (_tl.PeerUser, _tl.PeerChat, _tl.PeerChannel):
            result = self.__dict__.get(utils.get_peer_id(cls(item)))
            if result:
                return result

        raise KeyError('No cached entity for the given key')

    def clear(self):
        """
        Clear the entity cache.
        """
        self.__dict__.clear()

    def ensure_cached(
            self,
            update,
            has_user_id=frozenset(_has_field[('user_id', int)]),
            has_chat_id=frozenset(_has_field[('chat_id', int)]),
            has_channel_id=frozenset(_has_field[('channel_id', int)]),
            has_peer=frozenset(_has_field[('peer', 'TypePeer')] + _has_field[('peer', 'TypeDialogPeer')]),
            has_message=frozenset(_has_field[('message', 'TypeMessage')])
    ):
        """
        Ensures that all the relevant entities in the given update are cached.

        Returns ``False`` as soon as one of the IDs or peers referenced by
        the update is missing from the cache, and ``True`` otherwise. The
        ``has_*`` parameters are precomputed sets of constructor IDs and
        are not meant to be passed by callers.
        """
        # This method is called pretty often and we want it to have the lowest
        # overhead possible. For that, we avoid `isinstance` and constantly
        # getting attributes out of `_tl.` by "caching" the constructor IDs
        # in sets inside the arguments, and using local variables.
        dct = self.__dict__
        cid = update.CONSTRUCTOR_ID

        if cid in has_user_id and update.user_id not in dct:
            return False

        if cid in has_chat_id and update.chat_id not in dct:
            return False

        if cid in has_channel_id and update.channel_id not in dct:
            return False

        if cid in has_peer and utils.get_peer_id(update.peer) not in dct:
            return False

        if cid in has_message:
            x = update.message
            y = getattr(x, 'peer_id', None)  # handle MessageEmpty
            if y and utils.get_peer_id(y) not in dct:
                return False

            y = getattr(x, 'from_id', None)
            if y and utils.get_peer_id(y) not in dct:
                return False

        # We don't quite worry about entities anywhere else.
        # This is enough.
        return True
|
|
|
@ -1,164 +0,0 @@
|
||||||
import inspect
|
|
||||||
|
|
||||||
from .. import _tl
|
|
||||||
|
|
||||||
|
|
||||||
# Which updates have the following fields?
|
|
||||||
_has_channel_id = []
|
|
||||||
|
|
||||||
|
|
||||||
# TODO EntityCache does the same. Reuse?
|
|
||||||
def _fill():
    """
    Collect into ``_has_channel_id`` the matching constructor IDs.

    Every public attribute of ``_tl`` whose ``SUBCLASS_OF_ID`` equals
    ``0x9f89304e`` is inspected, and its ``CONSTRUCTOR_ID`` is appended
    for each ``__init__`` parameter named ``channel_id`` annotated as
    ``int``.

    Raises ``RuntimeError`` if nothing was collected.
    """
    # NOTE(review): presumably the subclass ID shared by all update
    # types -- confirm against the generated ``_tl`` module.
    wanted_subclass_id = 0x9f89304e

    for attr_name in dir(_tl):
        candidate = getattr(_tl, attr_name)
        if getattr(candidate, 'SUBCLASS_OF_ID', None) != wanted_subclass_id:
            continue

        constructor_id = candidate.CONSTRUCTOR_ID
        signature = inspect.signature(candidate.__init__)
        for parameter in signature.parameters.values():
            if parameter.name == 'channel_id' and parameter.annotation == int:
                _has_channel_id.append(constructor_id)

    # An empty result most likely means the scan above no longer
    # recognises the generated types.
    if not _has_channel_id:
        raise RuntimeError('FIXME: Did the init signature or updates change?')
|
|
||||||
|
|
||||||
|
|
||||||
# We use a function to avoid cluttering the globals (with name/update/cid/doc)
|
|
||||||
_fill()
|
|
||||||
|
|
||||||
|
|
||||||
class StateCache:
    """
    In-memory update state cache, defaultdict-like behaviour.

    The common ``(pts, date)`` pair is kept in ``self._pts_date``, while
    the ``pts`` of each channel is stored directly in the instance
    ``__dict__``, keyed by the channel ID.
    """
    def __init__(self, initial, loggers):
        """
        ``initial`` is either falsy or an object with ``pts`` and ``date``
        attributes used to seed the state. ``loggers`` is indexed with
        this module's ``__name__`` to obtain the logger to use.
        """
        # We only care about the pts and the date. By using a tuple which
        # is lightweight and immutable we can easily copy them around to
        # each update in case they need to fetch missing entities.
        self._logger = loggers[__name__]
        if initial:
            self._pts_date = initial.pts or None, initial.date or None
        else:
            self._pts_date = None, None

    def reset(self):
        """
        Forget every known state, both the common one and per-channel.
        """
        # Clearing `__dict__` also drops `_logger`, which `update` and
        # `get_channel_id` still need afterwards, so it must be restored.
        logger = self._logger
        self.__dict__.clear()
        self._logger = logger
        self._pts_date = None, None

    # TODO Call this when receiving responses too...?
    def update(
            self,
            update,
            *,
            channel_id=None,
            has_pts=frozenset(x.CONSTRUCTOR_ID for x in (
                _tl.UpdateNewMessage,
                _tl.UpdateDeleteMessages,
                _tl.UpdateReadHistoryInbox,
                _tl.UpdateReadHistoryOutbox,
                _tl.UpdateWebPage,
                _tl.UpdateReadMessagesContents,
                _tl.UpdateEditMessage,
                _tl.updates.State,
                _tl.updates.DifferenceTooLong,
                _tl.UpdateShortMessage,
                _tl.UpdateShortChatMessage,
                _tl.UpdateShortSentMessage
            )),
            has_date=frozenset(x.CONSTRUCTOR_ID for x in (
                _tl.UpdateUserPhoto,
                _tl.UpdateEncryption,
                _tl.UpdateEncryptedMessagesRead,
                _tl.UpdateChatParticipantAdd,
                _tl.updates.DifferenceEmpty,
                _tl.UpdateShortMessage,
                _tl.UpdateShortChatMessage,
                _tl.UpdateShort,
                _tl.UpdatesCombined,
                _tl.Updates,
                _tl.UpdateShortSentMessage,
            )),
            has_channel_pts=frozenset(x.CONSTRUCTOR_ID for x in (
                _tl.UpdateChannelTooLong,
                _tl.UpdateNewChannelMessage,
                _tl.UpdateDeleteChannelMessages,
                _tl.UpdateEditChannelMessage,
                _tl.UpdateChannelWebPage,
                _tl.updates.ChannelDifferenceEmpty,
                _tl.updates.ChannelDifferenceTooLong,
                _tl.updates.ChannelDifference
            )),
            check_only=False
    ):
        """
        Update the state with the given update.

        ``channel_id`` may be given when it is already known; otherwise
        it is extracted from the update when a channel ``pts`` has to be
        stored. The ``has_*`` parameters are precomputed sets of
        constructor IDs and are not meant to be passed by callers.

        With ``check_only=True`` nothing is modified, and the return
        value tells whether the update carries any state at all.
        """
        cid = update.CONSTRUCTOR_ID
        if check_only:
            return cid in has_pts or cid in has_date or cid in has_channel_pts

        if cid in has_pts:
            if cid in has_date:
                self._pts_date = update.pts, update.date
            else:
                # Only the pts is new, keep the date we already had.
                self._pts_date = update.pts, self._pts_date[1]
        elif cid in has_date:
            self._pts_date = self._pts_date[0], update.date

        if cid in has_channel_pts:
            if channel_id is None:
                channel_id = self.get_channel_id(update)

            if channel_id is None:
                self._logger.info(
                    'Failed to retrieve channel_id from %s', update)
            else:
                self.__dict__[channel_id] = update.pts

    def get_channel_id(
            self,
            update,
            has_channel_id=frozenset(_has_channel_id),
            # Hardcoded because only some with message are for channels
            has_message=frozenset(x.CONSTRUCTOR_ID for x in (
                _tl.UpdateNewChannelMessage,
                _tl.UpdateEditChannelMessage
            ))
    ):
        """
        Gets the **unmarked** channel ID from this update, if it has any.

        Fails for ``*difference`` updates, where ``channel_id``
        is supposedly already known from the outside.

        Returns `None` when no channel ID can be determined.
        """
        cid = update.CONSTRUCTOR_ID
        if cid in has_channel_id:
            return update.channel_id
        elif cid in has_message:
            if update.message.peer_id is None:
                # Telegram sometimes sends empty messages to give a newer pts:
                # UpdateNewChannelMessage(message=MessageEmpty(id), pts=pts, pts_count=1)
                # Not sure why, but it's safe to ignore them.
                self._logger.debug('Update has None peer_id %s', update)
            else:
                return update.message.peer_id.channel_id

        return None

    def __getitem__(self, item):
        """
        If `item` is `None`, returns the default ``(pts, date)``.

        If it's an **unmarked** channel ID, returns its ``pts``.

        If no information is known, ``pts`` will be `None`.
        """
        if item is None:
            return self._pts_date
        else:
            return self.__dict__.get(item)

    def __setitem__(self, where, value):
        """
        Stores ``value`` as the default ``(pts, date)`` if `where` is
        `None`, or as the ``pts`` of the channel with ID `where` otherwise.
        """
        if where is None:
            self._pts_date = value
        else:
            self.__dict__[where] = value
|
|
Loading…
Reference in New Issue
Block a user