mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-07-30 17:59:55 +03:00
Fix catch_up
updates.py: - Fix catch_up getting stuck in infinite loop by saving qts - Save new state to session after catch_up - Save initial state immediately if catch_up was ran first time - If DifferenceTooLong was received, continue fetching with new pts instead of exiting - Catch up channels - Implement pts limit (pts_total_limit in chats and limit in channels) - Check if update was already processed before processing telegrambaseclient.py, sessions: - Get saved state from channels on startup and save channels state statecache.py: - Create separate variable _store in StateCache, because __dict__ also stores _logger variable - Move has_pts, has_qts, has_date, has_channel_pts outside of function arguments Fixes #3041
This commit is contained in:
parent
3d350c6087
commit
34bb2b8fc3
|
@ -6,6 +6,7 @@ import logging
|
||||||
import platform
|
import platform
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from .. import version, helpers, __name__ as __base_name__
|
from .. import version, helpers, __name__ as __base_name__
|
||||||
from ..crypto import rsa
|
from ..crypto import rsa
|
||||||
|
@ -401,9 +402,10 @@ class TelegramBaseClient(abc.ABC):
|
||||||
self._authorized = None # None = unknown, False = no, True = yes
|
self._authorized = None # None = unknown, False = no, True = yes
|
||||||
|
|
||||||
# Update state (for catching up after a disconnection)
|
# Update state (for catching up after a disconnection)
|
||||||
# TODO Get state from channels too
|
|
||||||
self._state_cache = StateCache(
|
self._state_cache = StateCache(
|
||||||
self.session.get_update_state(0), self._log)
|
self.session.get_update_state(0), self._log)
|
||||||
|
for k, v in self.session.get_channel_pts().items():
|
||||||
|
self._state_cache[k] = v
|
||||||
|
|
||||||
# Some further state for subclasses
|
# Some further state for subclasses
|
||||||
self._event_builders = []
|
self._event_builders = []
|
||||||
|
@ -630,15 +632,23 @@ class TelegramBaseClient(abc.ABC):
|
||||||
await asyncio.wait(self._updates_queue)
|
await asyncio.wait(self._updates_queue)
|
||||||
self._updates_queue.clear()
|
self._updates_queue.clear()
|
||||||
|
|
||||||
pts, date = self._state_cache[None]
|
pts, qts, date = self._state_cache[None]
|
||||||
if pts and date:
|
if pts and date:
|
||||||
self.session.set_update_state(0, types.updates.State(
|
self.session.set_update_state(0, types.updates.State(
|
||||||
pts=pts,
|
pts=pts,
|
||||||
qts=0,
|
qts=qts,
|
||||||
date=date,
|
date=date,
|
||||||
seq=0,
|
seq=0,
|
||||||
unread_count=0
|
unread_count=0
|
||||||
))
|
))
|
||||||
|
for channel_id, pts in self._state_cache.get_channel_pts().items():
|
||||||
|
self.session.set_update_state(channel_id, types.updates.State(
|
||||||
|
pts=pts,
|
||||||
|
qts=0,
|
||||||
|
date=datetime.fromtimestamp(0),
|
||||||
|
seq=0,
|
||||||
|
unread_count=0
|
||||||
|
))
|
||||||
|
|
||||||
self.session.close()
|
self.session.close()
|
||||||
|
|
||||||
|
|
|
@ -7,10 +7,12 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from .. import events, utils, errors
|
from .. import events, utils, errors
|
||||||
from ..events.common import EventBuilder, EventCommon
|
from ..events.common import EventBuilder, EventCommon
|
||||||
from ..tl import types, functions
|
from ..tl import types, functions
|
||||||
|
from ..tl.types import UpdateChannelTooLong
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from .telegramclient import TelegramClient
|
from .telegramclient import TelegramClient
|
||||||
|
@ -211,7 +213,45 @@ class UpdateMethods:
|
||||||
"""
|
"""
|
||||||
return [(callback, event) for event, callback in self._event_builders]
|
return [(callback, event) for event, callback in self._event_builders]
|
||||||
|
|
||||||
async def catch_up(self: 'TelegramClient'):
|
async def _catch_up_channel(self: 'TelegramClient', channel_id: int, pts: int, limit: int = None):
|
||||||
|
if self._state_cache[channel_id]:
|
||||||
|
pts = self._state_cache[channel_id]
|
||||||
|
|
||||||
|
if not pts:
|
||||||
|
# First-time, can't get difference. Get pts instead.
|
||||||
|
result = await self(functions.channels.GetFullChannelRequest(channel_id))
|
||||||
|
pts = self._state_cache[channel_id] = result.full_chat.pts
|
||||||
|
self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
d = await self(functions.updates.GetChannelDifferenceRequest(
|
||||||
|
channel=channel_id,
|
||||||
|
filter=types.ChannelMessagesFilterEmpty(),
|
||||||
|
pts=pts,
|
||||||
|
limit=limit or 100000 # If the limit isn't set, fetch all updates that we can
|
||||||
|
))
|
||||||
|
if isinstance(d, types.updates.ChannelDifference):
|
||||||
|
pts = d.pts
|
||||||
|
self._handle_update(types.Updates(
|
||||||
|
users=d.users,
|
||||||
|
chats=d.chats,
|
||||||
|
date=None,
|
||||||
|
seq=0,
|
||||||
|
updates=d.other_updates + [
|
||||||
|
types.UpdateNewChannelMessage(m, 0, 0)
|
||||||
|
for m in d.new_messages
|
||||||
|
]
|
||||||
|
))
|
||||||
|
elif isinstance(d, (types.updates.ChannelDifferenceTooLong,
|
||||||
|
types.updates.ChannelDifferenceEmpty)):
|
||||||
|
# If there is too much updates (ChannelDifferenceTooLong),
|
||||||
|
# there is no way to get them without raising limit or GetHistoryRequest, so just break
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.session.set_update_state(channel_id, types.updates.State(pts, 0, datetime.fromtimestamp(0), 0, 0))
|
||||||
|
|
||||||
|
async def catch_up(self: 'TelegramClient', pts_total_limit=None, limit=None):
|
||||||
"""
|
"""
|
||||||
"Catches up" on the missed updates while the client was offline.
|
"Catches up" on the missed updates while the client was offline.
|
||||||
You should call this method after registering the event handlers
|
You should call this method after registering the event handlers
|
||||||
|
@ -224,15 +264,22 @@ class UpdateMethods:
|
||||||
|
|
||||||
await client.catch_up()
|
await client.catch_up()
|
||||||
"""
|
"""
|
||||||
pts, date = self._state_cache[None]
|
pts, qts, date = self._state_cache[None]
|
||||||
if not pts:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.session.catching_up = True
|
self.session.catching_up = True
|
||||||
try:
|
try:
|
||||||
|
if not pts:
|
||||||
|
# Ran first time, get initial pts, qts and date and return
|
||||||
|
result = await self(functions.updates.GetStateRequest())
|
||||||
|
pts, qts, date = result.pts, result.qts, result.date
|
||||||
|
return
|
||||||
|
if not qts:
|
||||||
|
qts = 0
|
||||||
|
|
||||||
|
channels_to_fetch = []
|
||||||
while True:
|
while True:
|
||||||
d = await self(functions.updates.GetDifferenceRequest(
|
d = await self(functions.updates.GetDifferenceRequest(
|
||||||
pts, date, 0
|
pts, date, qts, pts_total_limit
|
||||||
))
|
))
|
||||||
if isinstance(d, (types.updates.DifferenceSlice,
|
if isinstance(d, (types.updates.DifferenceSlice,
|
||||||
types.updates.Difference)):
|
types.updates.Difference)):
|
||||||
|
@ -241,43 +288,39 @@ class UpdateMethods:
|
||||||
else:
|
else:
|
||||||
state = d.intermediate_state
|
state = d.intermediate_state
|
||||||
|
|
||||||
pts, date = state.pts, state.date
|
updates = []
|
||||||
|
for update in d.other_updates:
|
||||||
|
if isinstance(update, UpdateChannelTooLong):
|
||||||
|
channels_to_fetch.append((update.channel_id, update.pts))
|
||||||
|
else:
|
||||||
|
updates.append(update)
|
||||||
|
|
||||||
|
pts, qts, date = state.pts, state.qts, state.date
|
||||||
self._handle_update(types.Updates(
|
self._handle_update(types.Updates(
|
||||||
users=d.users,
|
users=d.users,
|
||||||
chats=d.chats,
|
chats=d.chats,
|
||||||
date=state.date,
|
date=state.date,
|
||||||
seq=state.seq,
|
seq=state.seq,
|
||||||
updates=d.other_updates + [
|
updates=updates + [
|
||||||
types.UpdateNewMessage(m, 0, 0)
|
types.UpdateNewMessage(m, 0, 0)
|
||||||
for m in d.new_messages
|
for m in d.new_messages
|
||||||
]
|
]
|
||||||
))
|
))
|
||||||
|
elif isinstance(d, types.updates.DifferenceTooLong):
|
||||||
# TODO Implement upper limit (max_pts)
|
pts = d.pts
|
||||||
# We don't want to fetch updates we already know about.
|
# If the limit isn't set, fetch all updates that we can
|
||||||
#
|
if pts_total_limit is not None:
|
||||||
# We may still get duplicates because the Difference
|
break
|
||||||
# contains a lot of updates and presumably only has
|
elif isinstance(d, types.updates.DifferenceEmpty):
|
||||||
# the state for the last one, but at least we don't
|
date = d.date
|
||||||
# unnecessarily fetch too many.
|
|
||||||
#
|
|
||||||
# updates.getDifference's pts_total_limit seems to mean
|
|
||||||
# "how many pts is the request allowed to return", and
|
|
||||||
# if there is more than that, it returns "too long" (so
|
|
||||||
# there would be duplicate updates since we know about
|
|
||||||
# some). This can be used to detect collisions (i.e.
|
|
||||||
# it would return an update we have already seen).
|
|
||||||
else:
|
|
||||||
if isinstance(d, types.updates.DifferenceEmpty):
|
|
||||||
date = d.date
|
|
||||||
elif isinstance(d, types.updates.DifferenceTooLong):
|
|
||||||
pts = d.pts
|
|
||||||
break
|
break
|
||||||
|
for channel_id, channel_pts in channels_to_fetch:
|
||||||
|
await self._catch_up_channel(channel_id, channel_pts, limit)
|
||||||
except (ConnectionError, asyncio.CancelledError):
|
except (ConnectionError, asyncio.CancelledError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# TODO Save new pts to session
|
self._state_cache[None] = (pts, qts, date)
|
||||||
self._state_cache._pts_date = (pts, date)
|
self.session.set_update_state(0, types.updates.State(pts, qts, date, seq=0, unread_count=0))
|
||||||
self.session.catching_up = False
|
self.session.catching_up = False
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
@ -288,6 +331,9 @@ class UpdateMethods:
|
||||||
# the order that the updates arrive in to update the pts and date to
|
# the order that the updates arrive in to update the pts and date to
|
||||||
# be always-increasing. There is also no need to make this async.
|
# be always-increasing. There is also no need to make this async.
|
||||||
def _handle_update(self: 'TelegramClient', update):
|
def _handle_update(self: 'TelegramClient', update):
|
||||||
|
if self._state_cache.update_already_processed(update):
|
||||||
|
return
|
||||||
|
|
||||||
self.session.process_entities(update)
|
self.session.process_entities(update)
|
||||||
self._entity_cache.add(update)
|
self._entity_cache.add(update)
|
||||||
|
|
||||||
|
@ -304,6 +350,9 @@ class UpdateMethods:
|
||||||
self._state_cache.update(update)
|
self._state_cache.update(update)
|
||||||
|
|
||||||
def _process_update(self: 'TelegramClient', update, others, entities=None):
|
def _process_update(self: 'TelegramClient', update, others, entities=None):
|
||||||
|
if self._state_cache.update_already_processed(update):
|
||||||
|
return
|
||||||
|
|
||||||
update._entities = entities or {}
|
update._entities = entities or {}
|
||||||
|
|
||||||
# This part is somewhat hot so we don't bother patching
|
# This part is somewhat hot so we don't bother patching
|
||||||
|
@ -553,7 +602,7 @@ class UpdateMethods:
|
||||||
if not pts_date[0]:
|
if not pts_date[0]:
|
||||||
# First-time, can't get difference. Get pts instead.
|
# First-time, can't get difference. Get pts instead.
|
||||||
result = await self(functions.updates.GetStateRequest())
|
result = await self(functions.updates.GetStateRequest())
|
||||||
self._state_cache[None] = result.pts, result.date
|
self._state_cache[None] = result.pts, result.qts, result.date
|
||||||
return
|
return
|
||||||
|
|
||||||
result = await self(functions.updates.GetDifferenceRequest(
|
result = await self(functions.updates.GetDifferenceRequest(
|
||||||
|
|
|
@ -88,6 +88,13 @@ class Session(ABC):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_channel_pts(self):
|
||||||
|
"""
|
||||||
|
Returns the ``Dict[int, int]`` with pts for all saved channels
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_update_state(self, entity_id, state):
|
def set_update_state(self, entity_id, state):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -74,6 +74,9 @@ class MemorySession(Session):
|
||||||
def get_update_state(self, entity_id):
|
def get_update_state(self, entity_id):
|
||||||
return self._update_states.get(entity_id, None)
|
return self._update_states.get(entity_id, None)
|
||||||
|
|
||||||
|
def get_channel_pts(self):
|
||||||
|
return {x[0]: x[1] for x in self._update_states.items() if x[0] != 0}
|
||||||
|
|
||||||
def set_update_state(self, entity_id, state):
|
def set_update_state(self, entity_id, state):
|
||||||
self._update_states[entity_id] = state
|
self._update_states[entity_id] = state
|
||||||
|
|
||||||
|
|
|
@ -210,6 +210,14 @@ class SQLiteSession(MemorySession):
|
||||||
date, tz=datetime.timezone.utc)
|
date, tz=datetime.timezone.utc)
|
||||||
return types.updates.State(pts, qts, date, seq, unread_count=0)
|
return types.updates.State(pts, qts, date, seq, unread_count=0)
|
||||||
|
|
||||||
|
def get_channel_pts(self):
|
||||||
|
c = self._cursor()
|
||||||
|
try:
|
||||||
|
rows = c.execute('select id, pts from update_state').fetchall()
|
||||||
|
finally:
|
||||||
|
c.close()
|
||||||
|
return {x[0]: x[1] for x in rows if x[0] != 0}
|
||||||
|
|
||||||
def set_update_state(self, entity_id, state):
|
def set_update_state(self, entity_id, state):
|
||||||
self._execute('insert or replace into update_state values (?,?,?,?,?)',
|
self._execute('insert or replace into update_state values (?,?,?,?,?)',
|
||||||
entity_id, state.pts, state.qts,
|
entity_id, state.pts, state.qts,
|
||||||
|
|
|
@ -2,7 +2,6 @@ import inspect
|
||||||
|
|
||||||
from .tl import types
|
from .tl import types
|
||||||
|
|
||||||
|
|
||||||
# Which updates have the following fields?
|
# Which updates have the following fields?
|
||||||
_has_channel_id = []
|
_has_channel_id = []
|
||||||
|
|
||||||
|
@ -26,23 +25,70 @@ def _fill():
|
||||||
_fill()
|
_fill()
|
||||||
|
|
||||||
|
|
||||||
|
has_pts = frozenset(x.CONSTRUCTOR_ID for x in (
|
||||||
|
types.UpdateNewMessage,
|
||||||
|
types.UpdateDeleteMessages,
|
||||||
|
types.UpdateReadHistoryInbox,
|
||||||
|
types.UpdateReadHistoryOutbox,
|
||||||
|
types.UpdateWebPage,
|
||||||
|
types.UpdateReadMessagesContents,
|
||||||
|
types.UpdateEditMessage,
|
||||||
|
types.updates.State,
|
||||||
|
types.updates.DifferenceTooLong,
|
||||||
|
types.UpdateShortMessage,
|
||||||
|
types.UpdateShortChatMessage,
|
||||||
|
types.UpdateShortSentMessage
|
||||||
|
))
|
||||||
|
has_qts = frozenset(x.CONSTRUCTOR_ID for x in (
|
||||||
|
types.UpdateBotStopped,
|
||||||
|
types.UpdateNewEncryptedMessage,
|
||||||
|
types.updates.State
|
||||||
|
))
|
||||||
|
has_date = frozenset(x.CONSTRUCTOR_ID for x in (
|
||||||
|
types.UpdateUserPhoto,
|
||||||
|
types.UpdateEncryption,
|
||||||
|
types.UpdateEncryptedMessagesRead,
|
||||||
|
types.UpdateChatParticipantAdd,
|
||||||
|
types.updates.DifferenceEmpty,
|
||||||
|
types.UpdateShortMessage,
|
||||||
|
types.UpdateShortChatMessage,
|
||||||
|
types.UpdateShort,
|
||||||
|
types.UpdatesCombined,
|
||||||
|
types.Updates,
|
||||||
|
types.UpdateShortSentMessage,
|
||||||
|
))
|
||||||
|
has_channel_pts = frozenset(x.CONSTRUCTOR_ID for x in (
|
||||||
|
types.UpdateChannelTooLong,
|
||||||
|
types.UpdateNewChannelMessage,
|
||||||
|
types.UpdateDeleteChannelMessages,
|
||||||
|
types.UpdateEditChannelMessage,
|
||||||
|
types.UpdateChannelWebPage,
|
||||||
|
types.updates.ChannelDifferenceEmpty,
|
||||||
|
types.updates.ChannelDifferenceTooLong,
|
||||||
|
types.updates.ChannelDifference
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
class StateCache:
|
class StateCache:
|
||||||
"""
|
"""
|
||||||
In-memory update state cache, defaultdict-like behaviour.
|
In-memory update state cache, defaultdict-like behaviour.
|
||||||
"""
|
"""
|
||||||
|
_store: dict
|
||||||
|
|
||||||
def __init__(self, initial, loggers):
|
def __init__(self, initial, loggers):
|
||||||
# We only care about the pts and the date. By using a tuple which
|
# We only care about the pts and the date. By using a tuple which
|
||||||
# is lightweight and immutable we can easily copy them around to
|
# is lightweight and immutable we can easily copy them around to
|
||||||
# each update in case they need to fetch missing entities.
|
# each update in case they need to fetch missing entities.
|
||||||
self._logger = loggers[__name__]
|
self._logger = loggers[__name__]
|
||||||
|
self._store = {}
|
||||||
if initial:
|
if initial:
|
||||||
self._pts_date = initial.pts, initial.date
|
self._pts_date = initial.pts, initial.qts, initial.date
|
||||||
else:
|
else:
|
||||||
self._pts_date = None, None
|
self._pts_date = None, None, None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.__dict__.clear()
|
self._store.clear()
|
||||||
self._pts_date = None, None
|
self._pts_date = None, None, None
|
||||||
|
|
||||||
# TODO Call this when receiving responses too...?
|
# TODO Call this when receiving responses too...?
|
||||||
def update(
|
def update(
|
||||||
|
@ -50,43 +96,6 @@ class StateCache:
|
||||||
update,
|
update,
|
||||||
*,
|
*,
|
||||||
channel_id=None,
|
channel_id=None,
|
||||||
has_pts=frozenset(x.CONSTRUCTOR_ID for x in (
|
|
||||||
types.UpdateNewMessage,
|
|
||||||
types.UpdateDeleteMessages,
|
|
||||||
types.UpdateReadHistoryInbox,
|
|
||||||
types.UpdateReadHistoryOutbox,
|
|
||||||
types.UpdateWebPage,
|
|
||||||
types.UpdateReadMessagesContents,
|
|
||||||
types.UpdateEditMessage,
|
|
||||||
types.updates.State,
|
|
||||||
types.updates.DifferenceTooLong,
|
|
||||||
types.UpdateShortMessage,
|
|
||||||
types.UpdateShortChatMessage,
|
|
||||||
types.UpdateShortSentMessage
|
|
||||||
)),
|
|
||||||
has_date=frozenset(x.CONSTRUCTOR_ID for x in (
|
|
||||||
types.UpdateUserPhoto,
|
|
||||||
types.UpdateEncryption,
|
|
||||||
types.UpdateEncryptedMessagesRead,
|
|
||||||
types.UpdateChatParticipantAdd,
|
|
||||||
types.updates.DifferenceEmpty,
|
|
||||||
types.UpdateShortMessage,
|
|
||||||
types.UpdateShortChatMessage,
|
|
||||||
types.UpdateShort,
|
|
||||||
types.UpdatesCombined,
|
|
||||||
types.Updates,
|
|
||||||
types.UpdateShortSentMessage,
|
|
||||||
)),
|
|
||||||
has_channel_pts=frozenset(x.CONSTRUCTOR_ID for x in (
|
|
||||||
types.UpdateChannelTooLong,
|
|
||||||
types.UpdateNewChannelMessage,
|
|
||||||
types.UpdateDeleteChannelMessages,
|
|
||||||
types.UpdateEditChannelMessage,
|
|
||||||
types.UpdateChannelWebPage,
|
|
||||||
types.updates.ChannelDifferenceEmpty,
|
|
||||||
types.updates.ChannelDifferenceTooLong,
|
|
||||||
types.updates.ChannelDifference
|
|
||||||
)),
|
|
||||||
check_only=False
|
check_only=False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -96,13 +105,20 @@ class StateCache:
|
||||||
if check_only:
|
if check_only:
|
||||||
return cid in has_pts or cid in has_date or cid in has_channel_pts
|
return cid in has_pts or cid in has_date or cid in has_channel_pts
|
||||||
|
|
||||||
|
new_pts_date = tuple()
|
||||||
if cid in has_pts:
|
if cid in has_pts:
|
||||||
if cid in has_date:
|
new_pts_date += update.pts,
|
||||||
self._pts_date = update.pts, update.date
|
else:
|
||||||
else:
|
new_pts_date += self._pts_date[0],
|
||||||
self._pts_date = update.pts, self._pts_date[1]
|
if cid in has_qts:
|
||||||
elif cid in has_date:
|
new_pts_date += update.qts,
|
||||||
self._pts_date = self._pts_date[0], update.date
|
else:
|
||||||
|
new_pts_date += self._pts_date[1],
|
||||||
|
if cid in has_date:
|
||||||
|
new_pts_date += update.date,
|
||||||
|
else:
|
||||||
|
new_pts_date += self._pts_date[2],
|
||||||
|
self._pts_date = new_pts_date
|
||||||
|
|
||||||
if cid in has_channel_pts:
|
if cid in has_channel_pts:
|
||||||
if channel_id is None:
|
if channel_id is None:
|
||||||
|
@ -112,7 +128,23 @@ class StateCache:
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
'Failed to retrieve channel_id from %s', update)
|
'Failed to retrieve channel_id from %s', update)
|
||||||
else:
|
else:
|
||||||
self.__dict__[channel_id] = update.pts
|
self._store[channel_id] = update.pts
|
||||||
|
|
||||||
|
def update_already_processed(self, update):
|
||||||
|
cid = update.CONSTRUCTOR_ID
|
||||||
|
# If pts == 0, the update is from catch_up
|
||||||
|
if cid in has_pts and \
|
||||||
|
update.pts != 0 and \
|
||||||
|
update.pts >= self._pts_date[0]:
|
||||||
|
return True
|
||||||
|
if cid in has_qts and update.qts >= self._pts_date[1]:
|
||||||
|
return True
|
||||||
|
if cid in has_channel_pts:
|
||||||
|
channel_id = self.get_channel_id(update)
|
||||||
|
if update.pts != 0 and \
|
||||||
|
self._store.get(channel_id, 0) >= update.pts:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_channel_id(
|
def get_channel_id(
|
||||||
self,
|
self,
|
||||||
|
@ -120,8 +152,8 @@ class StateCache:
|
||||||
has_channel_id=frozenset(_has_channel_id),
|
has_channel_id=frozenset(_has_channel_id),
|
||||||
# Hardcoded because only some with message are for channels
|
# Hardcoded because only some with message are for channels
|
||||||
has_message=frozenset(x.CONSTRUCTOR_ID for x in (
|
has_message=frozenset(x.CONSTRUCTOR_ID for x in (
|
||||||
types.UpdateNewChannelMessage,
|
types.UpdateNewChannelMessage,
|
||||||
types.UpdateEditChannelMessage
|
types.UpdateEditChannelMessage
|
||||||
))
|
))
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -155,10 +187,13 @@ class StateCache:
|
||||||
if item is None:
|
if item is None:
|
||||||
return self._pts_date
|
return self._pts_date
|
||||||
else:
|
else:
|
||||||
return self.__dict__.get(item)
|
return self._store.get(item)
|
||||||
|
|
||||||
def __setitem__(self, where, value):
|
def __setitem__(self, where, value):
|
||||||
if where is None:
|
if where is None:
|
||||||
self._pts_date = value
|
self._pts_date = value
|
||||||
else:
|
else:
|
||||||
self.__dict__[where] = value
|
self._store[where] = value
|
||||||
|
|
||||||
|
def get_channel_pts(self):
|
||||||
|
return self._store
|
||||||
|
|
Loading…
Reference in New Issue
Block a user