Fix pts from channels is different (#1160)

This commit is contained in:
Lonami Exo 2019-04-21 13:56:14 +02:00
parent 8edbfbdced
commit c1880c9191
4 changed files with 141 additions and 50 deletions

View File

@ -375,11 +375,6 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._self_input_peer = utils.get_input_peer(user, allow_self=False)
self._authorized = True
# `catch_up` will getDifference from pts = 1, date = 1 (ignored)
# to fetch all updates (and obtain necessary access hashes) if
# the ``pts is None``.
self._old_pts_date = (None, None)
return user
async def send_code_request(self, phone, *, force_sms=False):
@ -436,8 +431,7 @@ class AuthMethods(MessageParseMethods, UserMethods):
self._bot = None
self._self_input_peer = None
self._authorized = False
self._old_pts_date = (None, None)
self._new_pts_date = (None, None)
self._state_cache.reset()
await self.disconnect()
self.session.delete()

View File

@ -13,6 +13,7 @@ from ..sessions import Session, SQLiteSession, MemorySession
from ..tl import TLObject, functions, types
from ..tl.alltlobjects import LAYER
from ..entitycache import EntityCache
from ..statecache import StateCache
DEFAULT_DC_ID = 4
DEFAULT_IPV4_IP = '149.154.167.51'
@ -306,13 +307,8 @@ class TelegramBaseClient(abc.ABC):
self._authorized = None # None = unknown, False = no, True = yes
# Update state (for catching up after a disconnection)
#
# We only care about the pts and the date. By using a tuple which
# is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities.
state = self.session.get_update_state(0)
self._old_pts_date = (state.pts, state.date) if state else (None, None)
self._new_pts_date = (None, None)
# TODO Get state from channels too
self._state_cache = StateCache(self.session.get_update_state(0))
# Some further state for subclasses
self._event_builders = []
@ -395,15 +391,14 @@ class TelegramBaseClient(abc.ABC):
async def _disconnect_coro(self):
await self._disconnect()
pts, date = self._new_pts_date
if pts:
self.session.set_update_state(0, types.updates.State(
pts=pts,
qts=0,
date=date or datetime.now(),
seq=0,
unread_count=0
))
pts, date = self._state_cache[None]
self.session.set_update_state(0, types.updates.State(
pts=pts,
qts=0,
date=date,
seq=0,
unread_count=0
))
self.session.close()

View File

@ -8,6 +8,7 @@ from .users import UserMethods
from .. import events, utils, errors
from ..tl import types, functions
from ..events.common import EventCommon
from ..statecache import StateCache
class UpdateMethods(UserMethods):
@ -135,14 +136,7 @@ class UpdateMethods(UserMethods):
This can also be used to forcibly fetch new updates if there are any.
"""
# TODO Since which state should we catch up?
if all(self._new_pts_date):
pts, date = self._new_pts_date
elif all(self._old_pts_date):
pts, date = self._old_pts_date
else:
return
pts, date = self._state_cache[None]
self.session.catching_up = True
try:
while True:
@ -192,7 +186,7 @@ class UpdateMethods(UserMethods):
pass
finally:
# TODO Save new pts to session
self._new_pts_date = (pts, date)
self._state_cache._pts_date = (pts, date)
self.session.catching_up = False
# endregion
@ -211,19 +205,15 @@ class UpdateMethods(UserMethods):
itertools.chain(update.users, update.chats)}
for u in update.updates:
self._process_update(u, entities)
self._new_pts_date = (self._new_pts_date[0], update.date)
elif isinstance(update, types.UpdateShort):
self._process_update(update.update)
self._new_pts_date = (self._new_pts_date[0], update.date)
else:
self._process_update(update)
# TODO Should this be done before or after?
self._update_pts_date(update)
self._state_cache.update(update)
def _process_update(self, update, entities=None):
update._pts_date = self._new_pts_date
update._pts_date = self._state_cache[StateCache.get_channel_id(update)]
update._entities = entities or {}
if self._updates_queue is None:
self._loop.create_task(self._dispatch_update(update))
@ -233,17 +223,7 @@ class UpdateMethods(UserMethods):
self._dispatching_updates_queue.set()
self._loop.create_task(self._dispatch_queue_updates())
self._update_pts_date(update)
def _update_pts_date(self, update):
pts, date = self._new_pts_date
if getattr(update, 'pts', None):
pts = update.pts
if getattr(update, 'date', None):
date = update.date
self._new_pts_date = (pts, date)
self._state_cache.update(update)
async def _update_loop(self):
# Pings' ID don't really need to be secure, just "random"
@ -416,7 +396,12 @@ class EventBuilderDict:
"""
# Fetch since the last known pts/date before this update arrived,
# in order to fetch this update at full.
pts, date = self.update._pts_date
pts_date = self.update._pts_date
if not isinstance(pts_date, tuple):
# TODO Handle channels, and handle this more nicely
return
pts, date = pts_date
if not pts:
return

117
telethon/statecache.py Normal file
View File

@ -0,0 +1,117 @@
import datetime
from .tl import types
class StateCache:
"""
In-memory update state cache, defaultdict-like behaviour.
"""
def __init__(self, initial):
# We only care about the pts and the date. By using a tuple which
# is lightweight and immutable we can easily copy them around to
# each update in case they need to fetch missing entities.
if initial:
self._pts_date = initial.pts, initial.date
else:
self._pts_date = 1, datetime.datetime.now()
def reset(self):
self.__dict__.clear()
self._pts_date = (1, 1)
# TODO Call this when receiving responses too...?
def update(
self,
update,
*,
channel_id=None,
has_pts=(
types.UpdateNewMessage,
types.UpdateDeleteMessages,
types.UpdateReadHistoryInbox,
types.UpdateReadHistoryOutbox,
types.UpdateWebPage,
types.UpdateReadMessagesContents,
types.UpdateEditMessage,
types.updates.State,
types.updates.DifferenceTooLong,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShortSentMessage
),
has_date=(
types.UpdateUserPhoto,
types.UpdateEncryption,
types.UpdateEncryptedMessagesRead,
types.UpdateChatParticipantAdd,
types.updates.DifferenceEmpty,
types.UpdateShortMessage,
types.UpdateShortChatMessage,
types.UpdateShort,
types.UpdatesCombined,
types.Updates,
types.UpdateShortSentMessage,
),
has_channel_pts=(
types.UpdateChannelTooLong,
types.UpdateNewChannelMessage,
types.UpdateDeleteChannelMessages,
types.UpdateEditChannelMessage,
types.UpdateChannelWebPage,
types.updates.ChannelDifferenceEmpty,
types.updates.ChannelDifferenceTooLong,
types.updates.ChannelDifference
)
):
"""
Update the state with the given update.
"""
has_pts = isinstance(update, has_pts)
has_date = isinstance(update, has_date)
has_channel_pts = isinstance(update, has_channel_pts)
if has_pts and has_date:
self._pts_date = update.pts, update.date
elif has_pts:
self._pts_date = update.pts, self._pts_date[1]
elif has_date:
self._pts_date = self._pts_date[0], update.date
if has_channel_pts:
if channel_id is None:
channel_id = self.get_channel_id(update)
if channel_id is None:
pass # TODO log, but shouldn't happen
else:
self.__dict__[channel_id] = update.pts
@staticmethod
def get_channel_id(
update,
has_channel_id=(
types.UpdateChannelTooLong,
types.UpdateDeleteChannelMessages,
types.UpdateChannelWebPage
),
has_message=(
types.UpdateNewChannelMessage,
types.UpdateEditChannelMessage
)
):
# Will only fail for *difference, where channel_id is known
if isinstance(update, has_channel_id):
return update.channel_id
elif isinstance(update, has_message):
return update.message.to_id.channel_id
else:
return None
def __getitem__(self, item):
"""
Gets the corresponding ``(pts, date)`` for the given ID or peer,
"""
if item is None:
return self._pts_date
else:
return self.__dict__.get(item, 1)