Make use of the new MessageBox

This commit is contained in:
Lonami Exo 2022-05-13 13:17:16 +02:00
parent b5bfe5d9a1
commit db09a92bc5
3 changed files with 101 additions and 230 deletions

View File

@ -15,6 +15,7 @@ from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
from ..sessions import Session, SQLiteSession, MemorySession from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import functions, types from ..tl import functions, types
from ..tl.alltlobjects import LAYER from ..tl.alltlobjects import LAYER
from .._updates import MessageBox, EntityCache as MbEntityCache
DEFAULT_DC_ID = 2 DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV4_IP = '149.154.167.51'
@ -376,18 +377,6 @@ class TelegramBaseClient(abc.ABC):
proxy=init_proxy proxy=init_proxy
) )
self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
update_callback=self._handle_update,
auto_reconnect_callback=self._handle_auto_reconnect
)
# Remember flood-waited requests to avoid making them again # Remember flood-waited requests to avoid making them again
self._flood_waited_requests = {} self._flood_waited_requests = {}
@ -396,18 +385,14 @@ class TelegramBaseClient(abc.ABC):
self._borrow_sender_lock = asyncio.Lock() self._borrow_sender_lock = asyncio.Lock()
self._updates_handle = None self._updates_handle = None
self._keepalive_handle = None
self._last_request = time.time() self._last_request = time.time()
self._channel_pts = {} self._channel_pts = {}
self._no_updates = not receive_updates self._no_updates = not receive_updates
if sequential_updates: # Used for non-sequential updates, in order to terminate all pending tasks on disconnect.
self._updates_queue = asyncio.Queue() self._sequential_updates = sequential_updates
self._dispatching_updates_queue = asyncio.Event() self._event_handler_tasks = set()
else:
# Use a set of pending instead of a queue so we can properly
# terminate all pending updates on disconnect.
self._updates_queue = set()
self._dispatching_updates_queue = None
self._authorized = None # None = unknown, False = no, True = yes self._authorized = None # None = unknown, False = no, True = yes
@ -442,6 +427,26 @@ class TelegramBaseClient(abc.ABC):
# A place to store if channels are a megagroup or not (see `edit_admin`) # A place to store if channels are a megagroup or not (see `edit_admin`)
self._megagroup_cache = {} self._megagroup_cache = {}
# This is backported from v2 in a very ad-hoc way just to get proper update handling
self._catch_up = True
self._updates_queue = asyncio.Queue()
self._message_box = MessageBox()
# This entity cache is tailored for the messagebox and is not used for absolutely everything like _entity_cache
self._mb_entity_cache = MbEntityCache() # required for proper update handling (to know when to getDifference)
self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
updates_queue=self._updates_queue,
auto_reconnect_callback=self._handle_auto_reconnect
)
# endregion # endregion
# region Properties # region Properties
@ -537,6 +542,7 @@ class TelegramBaseClient(abc.ABC):
)) ))
self._updates_handle = self.loop.create_task(self._update_loop()) self._updates_handle = self.loop.create_task(self._update_loop())
self._keepalive_handle = self.loop.create_task(self._keepalive_loop())
def is_connected(self: 'TelegramClient') -> bool: def is_connected(self: 'TelegramClient') -> bool:
""" """
@ -629,13 +635,12 @@ class TelegramBaseClient(abc.ABC):
# trio's nurseries would handle this for us, but this is asyncio. # trio's nurseries would handle this for us, but this is asyncio.
# All tasks spawned in the background should properly be terminated. # All tasks spawned in the background should properly be terminated.
if self._dispatching_updates_queue is None and self._updates_queue: if self._event_handler_tasks:
for task in self._updates_queue: for task in self._event_handler_tasks:
task.cancel() task.cancel()
await asyncio.wait(self._updates_queue) await asyncio.wait(self._event_handler_tasks)
self._updates_queue.clear() self._event_handler_tasks.clear()
await self.session.close() await self.session.close()
@ -648,7 +653,8 @@ class TelegramBaseClient(abc.ABC):
""" """
await self._sender.disconnect() await self._sender.disconnect()
await helpers._cancel(self._log[__name__], await helpers._cancel(self._log[__name__],
updates_handle=self._updates_handle) updates_handle=self._updates_handle,
keepalive_handle=self._keepalive_handle)
async def _switch_dc(self: 'TelegramClient', new_dc): async def _switch_dc(self: 'TelegramClient', new_dc):
""" """
@ -845,10 +851,6 @@ class TelegramBaseClient(abc.ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def _handle_update(self: 'TelegramClient', update):
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def _update_loop(self: 'TelegramClient'): def _update_loop(self: 'TelegramClient'):
raise NotImplementedError raise NotImplementedError

View File

@ -7,10 +7,12 @@ import time
import traceback import traceback
import typing import typing
import logging import logging
from collections import deque
from .. import events, utils, errors from .. import events, utils, errors
from ..events.common import EventBuilder, EventCommon from ..events.common import EventBuilder, EventCommon
from ..tl import types, functions from ..tl import types, functions
from .._updates import GapError
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .telegramclient import TelegramClient from .telegramclient import TelegramClient
@ -237,106 +239,76 @@ class UpdateMethods:
await client.catch_up() await client.catch_up()
""" """
pts, date = self._state_cache[None] await self._updates_queue.put(types.UpdatesTooLong())
if not pts:
return
self.session.catching_up = True
try:
while True:
d = await self(functions.updates.GetDifferenceRequest(
pts, date, 0
))
if isinstance(d, (types.updates.DifferenceSlice,
types.updates.Difference)):
if isinstance(d, types.updates.Difference):
state = d.state
else:
state = d.intermediate_state
pts, date = state.pts, state.date
await self._handle_update(types.Updates(
users=d.users,
chats=d.chats,
date=state.date,
seq=state.seq,
updates=d.other_updates + [
types.UpdateNewMessage(m, 0, 0)
for m in d.new_messages
]
))
# TODO Implement upper limit (max_pts)
# We don't want to fetch updates we already know about.
#
# We may still get duplicates because the Difference
# contains a lot of updates and presumably only has
# the state for the last one, but at least we don't
# unnecessarily fetch too many.
#
# updates.getDifference's pts_total_limit seems to mean
# "how many pts is the request allowed to return", and
# if there is more than that, it returns "too long" (so
# there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e.
# it would return an update we have already seen).
else:
if isinstance(d, types.updates.DifferenceEmpty):
date = d.date
elif isinstance(d, types.updates.DifferenceTooLong):
pts = d.pts
break
except (ConnectionError, asyncio.CancelledError):
pass
finally:
# TODO Save new pts to session
self._state_cache._pts_date = (pts, date)
self.session.catching_up = False
# endregion # endregion
# region Private methods # region Private methods
async def _update_loop(self: 'TelegramClient'):
    """
    Background loop that drives update handling while the client is connected.

    Each iteration does exactly one of the following, in priority order:

    1. Dispatch already-preprocessed updates. When ``_sequential_updates``
       is set a single update is awaited per iteration; otherwise every
       pending update is spawned as its own task, tracked in
       ``_event_handler_tasks`` so it can be cancelled on disconnect.
    2. Send the account-wide difference request, if the message box
       returns one, and queue the resulting updates.
    3. Send a channel difference request, if the message box returns
       one, and queue the resulting updates.
    4. Wait on ``_updates_queue`` until the deadline reported by the
       message box, then feed whatever arrived to the message box.

    Any unexpected exception is logged and terminates the loop.
    """
    try:
        updates_to_dispatch = deque()

        while self.is_connected():
            if updates_to_dispatch:
                if self._sequential_updates:
                    await self._dispatch_update(updates_to_dispatch.popleft())
                else:
                    while updates_to_dispatch:
                        task = self.loop.create_task(
                            self._dispatch_update(updates_to_dispatch.popleft()))
                        self._event_handler_tasks.add(task)
                        # The done-callback receives the finished task itself,
                        # so the bound `discard` removes exactly that task.
                        # A lambda closing over `task` would be late-bound to
                        # whichever task this loop created last, leaving all
                        # the earlier ones in the set forever.
                        task.add_done_callback(self._event_handler_tasks.discard)

                continue

            get_diff = self._message_box.get_difference()
            if get_diff:
                self._log[__name__].info('Getting difference for account updates')
                diff = await self(get_diff)
                updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache)
                updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
                continue

            get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
            if get_diff:
                self._log[__name__].info('Getting difference for channel updates')
                diff = await self(get_diff)
                updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache)
                updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
                continue

            # The deadline is compared against the running loop's clock, so
            # the timeout handed to wait_for is whatever time remains.
            deadline = self._message_box.check_deadlines()
            try:
                updates = await asyncio.wait_for(
                    self._updates_queue.get(),
                    deadline - asyncio.get_running_loop().time()
                )
            except asyncio.TimeoutError:
                self._log[__name__].info('Timeout waiting for updates expired')
                continue

            processed = []
            try:
                users, chats = self._message_box.process_updates(updates, self._mb_entity_cache, processed)
            except GapError:
                continue  # get(_channel)_difference will start returning requests

            updates_to_dispatch.extend(await self._preprocess_updates(processed, users, chats))
    except Exception:
        self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)')
async def _preprocess_updates(self, updates, users, chats):
    """
    Record the entities that came with a batch of updates and attach them.

    The users and chats are handed to the session, to the general entity
    cache and to the message-box entity cache. Every update then gets an
    ``_entities`` attribute mapping peer ID to entity; the same dictionary
    object is shared by all updates of the batch.

    :param updates: the update objects that are about to be dispatched.
    :param users: the users that accompanied those updates.
    :param chats: the chats that accompanied those updates.
    :return: the same ``updates`` that were passed in.
    """
    # There is no single `update` object in scope here (referencing one
    # raised NameError on every call), so the users and chats are wrapped
    # in a container with the `.users` / `.chats` shape.
    # NOTE(review): assumes both consumers only read entities out of the
    # container and never serialize it (hence `date=None`) -- confirm.
    container = types.Updates(
        updates=list(updates),
        users=users,
        chats=chats,
        date=None,
        seq=0
    )
    await self.session.process_entities(container)
    self._entity_cache.add(container)

    self._mb_entity_cache.extend(users, chats)
    entities = {utils.get_peer_id(x): x
                for x in itertools.chain(users, chats)}
    for u in updates:
        u._entities = entities
    return updates
self._state_cache.update(update) async def _keepalive_loop(self: 'TelegramClient'):
def _process_update(self: 'TelegramClient', update, others, entities=None):
update._entities = entities or {}
# This part is somewhat hot so we don't bother patching
# update with channel ID/its state. Instead we just pass
# arguments which is faster.
channel_id = self._state_cache.get_channel_id(update)
args = (update, others, channel_id, self._state_cache[channel_id])
if self._dispatching_updates_queue is None:
task = self.loop.create_task(self._dispatch_update(*args))
self._updates_queue.add(task)
task.add_done_callback(lambda _: self._updates_queue.discard(task))
else:
self._updates_queue.put_nowait(args)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
self.loop.create_task(self._dispatch_queue_updates())
self._state_cache.update(update)
async def _update_loop(self: 'TelegramClient'):
# Pings' ID don't really need to be secure, just "random" # Pings' ID don't really need to be secure, just "random"
rnd = lambda: random.randrange(-2**63, 2**63) rnd = lambda: random.randrange(-2**63, 2**63)
while self.is_connected(): while self.is_connected():
@ -374,50 +346,9 @@ class UpdateMethods:
# it every minute instead. No-op if there's nothing new. # it every minute instead. No-op if there's nothing new.
await self.session.save() await self.session.save()
# We need to send some content-related request at least hourly async def _dispatch_update(self: 'TelegramClient', update):
# for Telegram to keep delivering updates, otherwise they will # TODO only used for AlbumHack, and MessageBox is not really designed for this
# just stop even if we're connected. Do so every 30 minutes. others = None
#
# TODO Call getDifference instead since it's more relevant
if time.time() - self._last_request > 30 * 60:
if not await self.is_user_authorized():
# What can be the user doing for so
# long without being logged in...?
continue
try:
await self(functions.updates.GetStateRequest())
except (ConnectionError, asyncio.CancelledError):
return
async def _dispatch_queue_updates(self: 'TelegramClient'):
while not self._updates_queue.empty():
await self._dispatch_update(*self._updates_queue.get_nowait())
self._dispatching_updates_queue.clear()
async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, pts_date):
if not self._entity_cache.ensure_cached(update):
# We could add a lock to not fetch the same pts twice if we are
# already fetching it. However this does not happen in practice,
# which makes sense, because different updates have different pts.
if self._state_cache.update(update, check_only=True):
# If the update doesn't have pts, fetching won't do anything.
# For example, UpdateUserStatus or UpdateChatUserTyping.
try:
await self._get_difference(update, channel_id, pts_date)
except OSError:
pass # We were disconnected, that's okay
except errors.RPCError:
# There's a high chance the request fails because we lack
# the channel. Because these "happen sporadically" (#1428)
# we should be okay (no flood waits) even if more occur.
pass
except ValueError:
# There is a chance that GetFullChannelRequest and GetDifferenceRequest
# inside the _get_difference() function will end up with
# ValueError("Request was unsuccessful N time(s)") for whatever reasons.
pass
if not self._self_input_peer: if not self._self_input_peer:
# Some updates require our own ID, so we must make sure # Some updates require our own ID, so we must make sure
@ -523,67 +454,6 @@ class UpdateMethods:
name = getattr(callback, '__name__', repr(callback)) name = getattr(callback, '__name__', repr(callback))
self._log[__name__].exception('Unhandled exception on %s', name) self._log[__name__].exception('Unhandled exception on %s', name)
async def _get_difference(self: 'TelegramClient', update, channel_id, pts_date):
"""
Get the difference for this `channel_id` if any, then load entities.
Calls :tl:`updates.getDifference`, which fills the entities cache
(always done by `__call__`) and lets us know about the full entities.
"""
# Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full, including its entities.
self._log[__name__].debug('Getting difference for entities '
'for %r', update.__class__)
if channel_id:
# There are reports where we somehow call get channel difference
# with `InputPeerEmpty`. Check our assumptions to better debug
# this when it happens.
assert isinstance(channel_id, int), 'channel_id was {}, not int in {}'.format(type(channel_id), update)
try:
# Wrap the ID inside a peer to ensure we get a channel back.
where = await self.get_input_entity(types.PeerChannel(channel_id))
except ValueError:
# There's a high chance that this fails, since
# we are getting the difference to fetch entities.
return
if not pts_date:
# First-time, can't get difference. Get pts instead.
result = await self(functions.channels.GetFullChannelRequest(
utils.get_input_channel(where)
))
self._state_cache[channel_id] = result.full_chat.pts
return
result = await self(functions.updates.GetChannelDifferenceRequest(
channel=where,
filter=types.ChannelMessagesFilterEmpty(),
pts=pts_date, # just pts
limit=100,
force=True
))
else:
if not pts_date[0]:
# First-time, can't get difference. Get pts instead.
result = await self(functions.updates.GetStateRequest())
self._state_cache[None] = result.pts, result.date
return
result = await self(functions.updates.GetDifferenceRequest(
pts=pts_date[0],
date=pts_date[1],
qts=0
))
if isinstance(result, (types.updates.Difference,
types.updates.DifferenceSlice,
types.updates.ChannelDifference,
types.updates.ChannelDifferenceTooLong)):
update._entities.update({
utils.get_peer_id(x): x for x in
itertools.chain(result.users, result.chats)
})
async def _handle_auto_reconnect(self: 'TelegramClient'): async def _handle_auto_reconnect(self: 'TelegramClient'):
# TODO Catch-up # TODO Catch-up
# For now we make a high-level request to let Telegram # For now we make a high-level request to let Telegram

View File

@ -44,7 +44,7 @@ class MTProtoSender:
def __init__(self, auth_key, *, loggers, def __init__(self, auth_key, *, loggers,
retries=5, delay=1, auto_reconnect=True, connect_timeout=None, retries=5, delay=1, auto_reconnect=True, connect_timeout=None,
auth_key_callback=None, auth_key_callback=None,
update_callback=None, auto_reconnect_callback=None): updates_queue=None, auto_reconnect_callback=None):
self._connection = None self._connection = None
self._loggers = loggers self._loggers = loggers
self._log = loggers[__name__] self._log = loggers[__name__]
@ -53,7 +53,7 @@ class MTProtoSender:
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
self._connect_timeout = connect_timeout self._connect_timeout = connect_timeout
self._auth_key_callback = auth_key_callback self._auth_key_callback = auth_key_callback
self._update_callback = update_callback self._updates_queue = updates_queue
self._auto_reconnect_callback = auto_reconnect_callback self._auto_reconnect_callback = auto_reconnect_callback
self._connect_lock = asyncio.Lock() self._connect_lock = asyncio.Lock()
self._ping = None self._ping = None
@ -645,8 +645,7 @@ class MTProtoSender:
return return
self._log.debug('Handling update %s', message.obj.__class__.__name__) self._log.debug('Handling update %s', message.obj.__class__.__name__)
if self._update_callback: self._updates_queue.put_nowait(message.obj)
await self._update_callback(message.obj)
async def _handle_pong(self, message): async def _handle_pong(self, message):
""" """