Make use of the new MessageBox

This commit is contained in:
Lonami Exo 2022-05-13 13:17:16 +02:00
parent b5bfe5d9a1
commit db09a92bc5
3 changed files with 101 additions and 230 deletions

View File

@ -15,6 +15,7 @@ from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import functions, types
from ..tl.alltlobjects import LAYER
from .._updates import MessageBox, EntityCache as MbEntityCache
DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51'
@ -376,18 +377,6 @@ class TelegramBaseClient(abc.ABC):
proxy=init_proxy
)
self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
update_callback=self._handle_update,
auto_reconnect_callback=self._handle_auto_reconnect
)
# Remember flood-waited requests to avoid making them again
self._flood_waited_requests = {}
@ -396,18 +385,14 @@ class TelegramBaseClient(abc.ABC):
self._borrow_sender_lock = asyncio.Lock()
self._updates_handle = None
self._keepalive_handle = None
self._last_request = time.time()
self._channel_pts = {}
self._no_updates = not receive_updates
if sequential_updates:
self._updates_queue = asyncio.Queue()
self._dispatching_updates_queue = asyncio.Event()
else:
# Use a set of pending instead of a queue so we can properly
# terminate all pending updates on disconnect.
self._updates_queue = set()
self._dispatching_updates_queue = None
# Used for non-sequential updates, in order to terminate all pending tasks on disconnect.
self._sequential_updates = sequential_updates
self._event_handler_tasks = set()
self._authorized = None # None = unknown, False = no, True = yes
@ -442,6 +427,26 @@ class TelegramBaseClient(abc.ABC):
# A place to store if channels are a megagroup or not (see `edit_admin`)
self._megagroup_cache = {}
# This is backported from v2 in a very ad-hoc way just to get proper update handling
self._catch_up = True
self._updates_queue = asyncio.Queue()
self._message_box = MessageBox()
# This entity cache is tailored for the messagebox and is not used for absolutely everything like _entity_cache
self._mb_entity_cache = MbEntityCache() # required for proper update handling (to know when to getDifference)
self._sender = MTProtoSender(
self.session.auth_key,
loggers=self._log,
retries=self._connection_retries,
delay=self._retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout,
auth_key_callback=self._auth_key_callback,
updates_queue=self._updates_queue,
auto_reconnect_callback=self._handle_auto_reconnect
)
# endregion
# region Properties
@ -537,6 +542,7 @@ class TelegramBaseClient(abc.ABC):
))
self._updates_handle = self.loop.create_task(self._update_loop())
self._keepalive_handle = self.loop.create_task(self._keepalive_loop())
def is_connected(self: 'TelegramClient') -> bool:
"""
@ -629,13 +635,12 @@ class TelegramBaseClient(abc.ABC):
# trio's nurseries would handle this for us, but this is asyncio.
# All tasks spawned in the background should properly be terminated.
if self._dispatching_updates_queue is None and self._updates_queue:
for task in self._updates_queue:
if self._event_handler_tasks:
for task in self._event_handler_tasks:
task.cancel()
await asyncio.wait(self._updates_queue)
self._updates_queue.clear()
await asyncio.wait(self._event_handler_tasks)
self._event_handler_tasks.clear()
await self.session.close()
@ -648,7 +653,8 @@ class TelegramBaseClient(abc.ABC):
"""
await self._sender.disconnect()
await helpers._cancel(self._log[__name__],
updates_handle=self._updates_handle)
updates_handle=self._updates_handle,
keepalive_handle=self._keepalive_handle)
async def _switch_dc(self: 'TelegramClient', new_dc):
"""
@ -845,10 +851,6 @@ class TelegramBaseClient(abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
def _handle_update(self: 'TelegramClient', update):
raise NotImplementedError
@abc.abstractmethod
def _update_loop(self: 'TelegramClient'):
raise NotImplementedError

View File

@ -7,10 +7,12 @@ import time
import traceback
import typing
import logging
from collections import deque
from .. import events, utils, errors
from ..events.common import EventBuilder, EventCommon
from ..tl import types, functions
from .._updates import GapError
if typing.TYPE_CHECKING:
from .telegramclient import TelegramClient
@ -237,106 +239,76 @@ class UpdateMethods:
await client.catch_up()
"""
pts, date = self._state_cache[None]
if not pts:
return
self.session.catching_up = True
try:
while True:
d = await self(functions.updates.GetDifferenceRequest(
pts, date, 0
))
if isinstance(d, (types.updates.DifferenceSlice,
types.updates.Difference)):
if isinstance(d, types.updates.Difference):
state = d.state
else:
state = d.intermediate_state
pts, date = state.pts, state.date
await self._handle_update(types.Updates(
users=d.users,
chats=d.chats,
date=state.date,
seq=state.seq,
updates=d.other_updates + [
types.UpdateNewMessage(m, 0, 0)
for m in d.new_messages
]
))
# TODO Implement upper limit (max_pts)
# We don't want to fetch updates we already know about.
#
# We may still get duplicates because the Difference
# contains a lot of updates and presumably only has
# the state for the last one, but at least we don't
# unnecessarily fetch too many.
#
# updates.getDifference's pts_total_limit seems to mean
# "how many pts is the request allowed to return", and
# if there is more than that, it returns "too long" (so
# there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e.
# it would return an update we have already seen).
else:
if isinstance(d, types.updates.DifferenceEmpty):
date = d.date
elif isinstance(d, types.updates.DifferenceTooLong):
pts = d.pts
break
except (ConnectionError, asyncio.CancelledError):
pass
finally:
# TODO Save new pts to session
self._state_cache._pts_date = (pts, date)
self.session.catching_up = False
await self._updates_queue.put(types.UpdatesTooLong())
# endregion
# region Private methods
# It is important to not make _handle_update async because we rely on
# the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async.
async def _handle_update(self: 'TelegramClient', update):
async def _update_loop(self: 'TelegramClient'):
try:
updates_to_dispatch = deque()
while self.is_connected():
if updates_to_dispatch:
if self._sequential_updates:
await self._dispatch_update(updates_to_dispatch.popleft())
else:
while updates_to_dispatch:
task = self.loop.create_task(self._dispatch_update(updates_to_dispatch.popleft()))
self._event_handler_tasks.add(task)
task.add_done_callback(lambda _: self._event_handler_tasks.discard(task))
continue
get_diff = self._message_box.get_difference()
if get_diff:
self._log[__name__].info('Getting difference for account updates')
diff = await self(get_diff)
updates, users, chats = self._message_box.apply_difference(diff, self._mb_entity_cache)
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
continue
get_diff = self._message_box.get_channel_difference(self._mb_entity_cache)
if get_diff:
self._log[__name__].info('Getting difference for channel updates')
diff = await self(get_diff)
updates, users, chats = self._message_box.apply_channel_difference(get_diff, diff, self._mb_entity_cache)
updates_to_dispatch.extend(await self._preprocess_updates(updates, users, chats))
continue
deadline = self._message_box.check_deadlines()
try:
updates = await asyncio.wait_for(
self._updates_queue.get(),
deadline - asyncio.get_running_loop().time()
)
except asyncio.TimeoutError:
self._log[__name__].info('Timeout waiting for updates expired')
continue
processed = []
try:
users, chats = self._message_box.process_updates(updates, self._mb_entity_cache, processed)
except GapError:
continue # get(_channel)_difference will start returning requests
updates_to_dispatch.extend(await self._preprocess_updates(processed, users, chats))
except Exception:
self._log[__name__].exception('Fatal error handling updates (this is a bug in Telethon, please report it)')
async def _preprocess_updates(self, updates, users, chats):
await self.session.process_entities(update)
self._entity_cache.add(update)
if isinstance(update, (types.Updates, types.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
self._process_update(u, update.updates, entities=entities)
elif isinstance(update, types.UpdateShort):
self._process_update(update.update, None)
else:
self._process_update(update, None)
self._mb_entity_cache.extend(users, chats)
entities = {utils.get_peer_id(x): x
for x in itertools.chain(users, chats)}
for u in updates:
u._entities = entities
return updates
self._state_cache.update(update)
def _process_update(self: 'TelegramClient', update, others, entities=None):
update._entities = entities or {}
# This part is somewhat hot so we don't bother patching
# update with channel ID/its state. Instead we just pass
# arguments which is faster.
channel_id = self._state_cache.get_channel_id(update)
args = (update, others, channel_id, self._state_cache[channel_id])
if self._dispatching_updates_queue is None:
task = self.loop.create_task(self._dispatch_update(*args))
self._updates_queue.add(task)
task.add_done_callback(lambda _: self._updates_queue.discard(task))
else:
self._updates_queue.put_nowait(args)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
self.loop.create_task(self._dispatch_queue_updates())
self._state_cache.update(update)
async def _update_loop(self: 'TelegramClient'):
async def _keepalive_loop(self: 'TelegramClient'):
# Pings' ID don't really need to be secure, just "random"
rnd = lambda: random.randrange(-2**63, 2**63)
while self.is_connected():
@ -374,50 +346,9 @@ class UpdateMethods:
# it every minute instead. No-op if there's nothing new.
await self.session.save()
# We need to send some content-related request at least hourly
# for Telegram to keep delivering updates, otherwise they will
# just stop even if we're connected. Do so every 30 minutes.
#
# TODO Call getDifference instead since it's more relevant
if time.time() - self._last_request > 30 * 60:
if not await self.is_user_authorized():
# What can be the user doing for so
# long without being logged in...?
continue
try:
await self(functions.updates.GetStateRequest())
except (ConnectionError, asyncio.CancelledError):
return
async def _dispatch_queue_updates(self: 'TelegramClient'):
while not self._updates_queue.empty():
await self._dispatch_update(*self._updates_queue.get_nowait())
self._dispatching_updates_queue.clear()
async def _dispatch_update(self: 'TelegramClient', update, others, channel_id, pts_date):
if not self._entity_cache.ensure_cached(update):
# We could add a lock to not fetch the same pts twice if we are
# already fetching it. However this does not happen in practice,
# which makes sense, because different updates have different pts.
if self._state_cache.update(update, check_only=True):
# If the update doesn't have pts, fetching won't do anything.
# For example, UpdateUserStatus or UpdateChatUserTyping.
try:
await self._get_difference(update, channel_id, pts_date)
except OSError:
pass # We were disconnected, that's okay
except errors.RPCError:
# There's a high chance the request fails because we lack
# the channel. Because these "happen sporadically" (#1428)
# we should be okay (no flood waits) even if more occur.
pass
except ValueError:
# There is a chance that GetFullChannelRequest and GetDifferenceRequest
# inside the _get_difference() function will end up with
# ValueError("Request was unsuccessful N time(s)") for whatever reasons.
pass
async def _dispatch_update(self: 'TelegramClient', update):
# TODO only used for AlbumHack, and MessageBox is not really designed for this
others = None
if not self._self_input_peer:
# Some updates require our own ID, so we must make sure
@ -523,67 +454,6 @@ class UpdateMethods:
name = getattr(callback, '__name__', repr(callback))
self._log[__name__].exception('Unhandled exception on %s', name)
async def _get_difference(self: 'TelegramClient', update, channel_id, pts_date):
"""
Get the difference for this `channel_id` if any, then load entities.
Calls :tl:`updates.getDifference`, which fills the entities cache
(always done by `__call__`) and lets us know about the full entities.
"""
# Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full, including its entities.
self._log[__name__].debug('Getting difference for entities '
'for %r', update.__class__)
if channel_id:
# There are reports where we somehow call get channel difference
# with `InputPeerEmpty`. Check our assumptions to better debug
# this when it happens.
assert isinstance(channel_id, int), 'channel_id was {}, not int in {}'.format(type(channel_id), update)
try:
# Wrap the ID inside a peer to ensure we get a channel back.
where = await self.get_input_entity(types.PeerChannel(channel_id))
except ValueError:
# There's a high chance that this fails, since
# we are getting the difference to fetch entities.
return
if not pts_date:
# First-time, can't get difference. Get pts instead.
result = await self(functions.channels.GetFullChannelRequest(
utils.get_input_channel(where)
))
self._state_cache[channel_id] = result.full_chat.pts
return
result = await self(functions.updates.GetChannelDifferenceRequest(
channel=where,
filter=types.ChannelMessagesFilterEmpty(),
pts=pts_date, # just pts
limit=100,
force=True
))
else:
if not pts_date[0]:
# First-time, can't get difference. Get pts instead.
result = await self(functions.updates.GetStateRequest())
self._state_cache[None] = result.pts, result.date
return
result = await self(functions.updates.GetDifferenceRequest(
pts=pts_date[0],
date=pts_date[1],
qts=0
))
if isinstance(result, (types.updates.Difference,
types.updates.DifferenceSlice,
types.updates.ChannelDifference,
types.updates.ChannelDifferenceTooLong)):
update._entities.update({
utils.get_peer_id(x): x for x in
itertools.chain(result.users, result.chats)
})
async def _handle_auto_reconnect(self: 'TelegramClient'):
# TODO Catch-up
# For now we make a high-level request to let Telegram

View File

@ -44,7 +44,7 @@ class MTProtoSender:
def __init__(self, auth_key, *, loggers,
retries=5, delay=1, auto_reconnect=True, connect_timeout=None,
auth_key_callback=None,
update_callback=None, auto_reconnect_callback=None):
updates_queue=None, auto_reconnect_callback=None):
self._connection = None
self._loggers = loggers
self._log = loggers[__name__]
@ -53,7 +53,7 @@ class MTProtoSender:
self._auto_reconnect = auto_reconnect
self._connect_timeout = connect_timeout
self._auth_key_callback = auth_key_callback
self._update_callback = update_callback
self._updates_queue = updates_queue
self._auto_reconnect_callback = auto_reconnect_callback
self._connect_lock = asyncio.Lock()
self._ping = None
@ -645,8 +645,7 @@ class MTProtoSender:
return
self._log.debug('Handling update %s', message.obj.__class__.__name__)
if self._update_callback:
await self._update_callback(message.obj)
self._updates_queue.put_nowait(message.obj)
async def _handle_pong(self, message):
"""