Save pts and date in a tuple for immutability

This way it is easy and cheap to copy the two required values
to every incoming update, in case we need to call getDifference
from the previous pts/date to fetch missing entities.

This is still a work in progress.
This commit is contained in:
Lonami Exo 2019-04-10 21:09:15 +04:00
parent bec0fa414e
commit 9965cda968
4 changed files with 68 additions and 76 deletions

View File

@ -375,11 +375,10 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._self_input_peer = utils.get_input_peer(user, allow_self=False) self._self_input_peer = utils.get_input_peer(user, allow_self=False)
self._authorized = True self._authorized = True
# By setting state.pts = 1 after logging in, the user or bot can # `catch_up` will getDifference from pts = 1, date = 1 (ignored)
# `catch_up` on all updates (and obtain necessary access hashes) # to fetch all updates (and obtain necessary access hashes) if
# if they desire. The date parameter is ignored when pts = 1. # the ``pts is None``.
self._old_state = types.updates.State( self._old_pts_date = (None, None)
1, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
return user return user
@ -437,8 +436,8 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._bot = None self._bot = None
self._self_input_peer = None self._self_input_peer = None
self._authorized = False self._authorized = False
self._old_state = None self._old_pts_date = (None, None)
self._new_state = None self._new_pts_date = (None, None)
await self.disconnect() await self.disconnect()
self.session.delete() self.session.delete()

View File

@ -306,15 +306,13 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection) # Update state (for catching up after a disconnection)
self._old_state = self.session.get_update_state(0)
self._new_state = None
# If we catch up, while we don't get disconnected,
# the old state will be the same as the new one.
# #
# If we do get disconnected, then the old and new # We only care about the pts and the date. By using a tuple which
# state may differ. # is lightweight and immutable we can easily copy them around to
self._old_state_is_new = False # each update in case they need to fetch missing entities.
state = self.session.get_update_state(0)
self._old_pts_date = state.pts, state.date
self._new_pts_date = (None, None)
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
@ -397,11 +395,15 @@ class TelegramBaseClient(abc.ABC):
async def _disconnect_coro(self): async def _disconnect_coro(self):
await self._disconnect() await self._disconnect()
# If we disconnect, the old state is the last one we are aware of pts, date = self._new_pts_date
self._old_state_is_new = True if pts:
self.session.set_update_state(0, types.updates.State(
if self._new_state: pts=pts,
self.session.set_update_state(0, self._new_state) qts=0,
date=date or datetime.now(),
seq=0,
unread_count=0
))
self.session.close() self.session.close()

View File

@ -135,23 +135,19 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any. This can also be used to forcibly fetch new updates if there are any.
""" """
state = self._new_state if self._old_state_is_new else self._old_state # TODO Since which state should we catch up?
if not self._old_state_is_new and self._new_state: if all(self._new_pts_date):
max_pts = self._new_state.pts pts, date = self._new_pts_date
elif all(self._old_pts_date):
pts, date = self._new_pts_date
else: else:
max_pts = float('inf') return
# No known state -> catch up since the beginning (date is ignored).
# Note: pts = 0 is invalid (and so is no date/unix timestamp = 0).
if not state:
state = types.updates.State(
1, 0, datetime.datetime.now(tz=datetime.timezone.utc), 0, 0)
self.session.catching_up = True self.session.catching_up = True
try: try:
while True: while True:
d = await self(functions.updates.GetDifferenceRequest( d = await self(functions.updates.GetDifferenceRequest(
state.pts, state.date, state.qts pts, date, 0
)) ))
if isinstance(d, (types.updates.DifferenceSlice, if isinstance(d, (types.updates.DifferenceSlice,
types.updates.Difference)): types.updates.Difference)):
@ -160,7 +156,8 @@ class UpdateMethods(UserMethods):
else: else:
state = d.intermediate_state state = d.intermediate_state
await self._handle_update(types.Updates( pts, date = state.pts, state.date
self._handle_update(types.Updates(
users=d.users, users=d.users,
chats=d.chats, chats=d.chats,
date=state.date, date=state.date,
@ -171,6 +168,7 @@ class UpdateMethods(UserMethods):
] ]
)) ))
# TODO Implement upper limit (max_pts)
# We don't want to fetch updates we already know about. # We don't want to fetch updates we already know about.
# #
# We may still get duplicates because the Difference # We may still get duplicates because the Difference
@ -184,29 +182,27 @@ class UpdateMethods(UserMethods):
# there would be duplicate updates since we know about # there would be duplicate updates since we know about
# some). This can be used to detect collisions (i.e. # some). This can be used to detect collisions (i.e.
# it would return an update we have already seen). # it would return an update we have already seen).
if state.pts >= max_pts:
break
else: else:
if isinstance(d, types.updates.DifferenceEmpty): if isinstance(d, types.updates.DifferenceEmpty):
state.date = d.date date = d.date
state.seq = d.seq
elif isinstance(d, types.updates.DifferenceTooLong): elif isinstance(d, types.updates.DifferenceTooLong):
state.pts = d.pts pts = d.pts
break break
except (ConnectionError, asyncio.CancelledError): except (ConnectionError, asyncio.CancelledError):
pass pass
finally: finally:
self._old_state = None # TODO Save new pts to session
self._new_state = state self._new_pts_date = (pts, date)
self._old_state_is_new = True
self.session.set_update_state(0, state)
self.session.catching_up = False self.session.catching_up = False
# endregion # endregion
# region Private methods # region Private methods
async def _handle_update(self, update): # It is important to not make _handle_update async because we rely on
# the order that the updates arrive in to update the pts and date to
# be always-increasing. There is also no need to make this async.
def _handle_update(self, update):
self.session.process_entities(update) self.session.process_entities(update)
self._entity_cache.add(update) self._entity_cache.add(update)
@ -214,40 +210,39 @@ class UpdateMethods(UserMethods):
entities = {utils.get_peer_id(x): x for x in entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)} itertools.chain(update.users, update.chats)}
for u in update.updates: for u in update.updates:
u._entities = entities self._process_update(u, entities)
await self._handle_update(u)
self._new_pts_date = (self._new_pts_date[0], update.date)
elif isinstance(update, types.UpdateShort): elif isinstance(update, types.UpdateShort):
await self._handle_update(update.update) self._process_update(update.update)
self._new_pts_date = (self._new_pts_date[0], update.date)
else: else:
update._entities = getattr(update, '_entities', {}) self._process_update(update)
if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update))
else:
self._updates_queue.put_nowait(update)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates())
# TODO make use of need_diff # TODO Should this be done before or after?
need_diff = False self._update_pts_date(update)
def _process_update(self, update, entities=None):
update._entities = entities or {}
if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update))
else:
self._updates_queue.put_nowait(update)
if not self._dispatching_updates_queue.is_set():
self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates())
self._update_pts_date(update)
def _update_pts_date(self, update):
pts, date = self._new_pts_date
if getattr(update, 'pts', None): if getattr(update, 'pts', None):
if not self._new_state: pts = update.pts
self._new_state = types.updates.State(
update.pts,
0,
getattr(update, 'date', datetime.datetime.now(tz=datetime.timezone.utc)),
getattr(update, 'seq', 0),
0
)
else:
if self._new_state.pts and (update.pts - self._new_state.pts) > 1:
need_diff = True
self._new_state.pts = update.pts if getattr(update, 'date', None):
if hasattr(update, 'date'): date = update.date
self._new_state.date = update.date
if hasattr(update, 'seq'): self._new_pts_date = (pts, date)
self._new_state.seq = update.seq
async def _update_loop(self): async def _update_loop(self):
# Pings' ID don't really need to be secure, just "random" # Pings' ID don't really need to be secure, just "random"
@ -368,10 +363,6 @@ class UpdateMethods(UserMethods):
# If a disconnection occurs, the old known state will be # If a disconnection occurs, the old known state will be
# the latest one we were aware of, so we can catch up since # the latest one we were aware of, so we can catch up since
# the most recent state we were aware of. # the most recent state we were aware of.
# TODO Ideally we set _old_state = _new_state *on* disconnect,
# not *after* we managed to reconnect since perhaps an update
# arrives just before we can get started.
self._old_state_is_new = True
await self.catch_up() await self.catch_up()
self._log[__name__].info('Successfully fetched missed updates') self._log[__name__].info('Successfully fetched missed updates')

View File

@ -547,7 +547,7 @@ class MTProtoSender:
self._log.debug('Handling update {}' self._log.debug('Handling update {}'
.format(message.obj.__class__.__name__)) .format(message.obj.__class__.__name__))
if self._update_callback: if self._update_callback:
await self._update_callback(message.obj) self._update_callback(message.obj)
async def _handle_pong(self, message): async def _handle_pong(self, message):
""" """