diff --git a/telethon/client/dialogs.py b/telethon/client/dialogs.py index 6c7edd24..bf3f903f 100644 --- a/telethon/client/dialogs.py +++ b/telethon/client/dialogs.py @@ -1,18 +1,105 @@ import itertools -from async_generator import async_generator, yield_ - from .users import UserMethods -from .. import utils, helpers +from .. import utils +from ..requestiter import RequestIter from ..tl import types, functions, custom +class _DialogsIter(RequestIter): + async def _init( + self, offset_date, offset_id, offset_peer, ignore_migrated + ): + self.request = functions.messages.GetDialogsRequest( + offset_date=offset_date, + offset_id=offset_id, + offset_peer=offset_peer, + limit=1, + hash=0 + ) + + if self.limit == 0: + # Special case, get a single dialog and determine count + dialogs = await self.client(self.request) + self.total = getattr(dialogs, 'count', len(dialogs.dialogs)) + raise StopAsyncIteration + + self.seen = set() + self.offset_date = offset_date + self.ignore_migrated = ignore_migrated + + async def _load_next_chunk(self): + result = [] + + self.request.limit = min(self.left, 100) + r = await self.client(self.request) + + self.total = getattr(r, 'count', len(r.dialogs)) + + entities = {utils.get_peer_id(x): x + for x in itertools.chain(r.users, r.chats)} + + messages = {} + for m in r.messages: + m._finish_init(self, entities, None) + messages[m.id] = m + + for d in r.dialogs: + # We check the offset date here because Telegram may ignore it + if self.offset_date: + date = getattr(messages.get( + d.top_message, None), 'date', None) + + if not date or date.timestamp() > self.offset_date.timestamp(): + continue + + peer_id = utils.get_peer_id(d.peer) + if peer_id not in self.seen: + self.seen.add(peer_id) + cd = custom.Dialog(self, d, entities, messages) + if cd.dialog.pts: + self.client._channel_pts[cd.id] = cd.dialog.pts + + if not self.ignore_migrated or getattr( + cd.entity, 'migrated_to', None) is None: + result.append(cd) + + if len(r.dialogs) < self.request.limit\ + or not isinstance(r, types.messages.DialogsSlice): + # Less than we requested means we reached the end, or + # we didn't get a DialogsSlice which means we got all. + self.left = len(result) + + self.request.offset_date = r.messages[-1].date + self.request.offset_peer =\ + entities[utils.get_peer_id(r.dialogs[-1].peer)] + + if self.request.offset_id == r.messages[-1].id: + # In some very rare cases this will get stuck in an infinite + # loop, where the offsets will get reused over and over. If + # the new offset is the same as the one before, break already. + self.left = len(result) + + self.request.offset_id = r.messages[-1].id + self.request.exclude_pinned = True + return result + + +class _DraftsIter(RequestIter): + async def _init(self, **kwargs): + r = await self.client(functions.messages.GetAllDraftsRequest()) + self.buffer = [custom.Draft._from_update(self.client, u) + for u in r.updates] + + async def _load_next_chunk(self): + return [] + + class DialogMethods(UserMethods): # region Public methods - @async_generator - async def iter_dialogs( + def iter_dialogs( self, limit=None, *, offset_date=None, offset_id=0, offset_peer=types.InputPeerEmpty(), ignore_migrated=False, _total=None): @@ -50,99 +137,23 @@ class DialogMethods(UserMethods): Yields: Instances of `telethon.tl.custom.dialog.Dialog`. """ - limit = float('inf') if limit is None else int(limit) - if limit == 0: - if not _total: - return - # Special case, get a single dialog and determine count - dialogs = await self(functions.messages.GetDialogsRequest( - offset_date=offset_date, - offset_id=offset_id, - offset_peer=offset_peer, - limit=1, - hash=0 - )) - _total[0] = getattr(dialogs, 'count', len(dialogs.dialogs)) - return - - seen = set() - req = functions.messages.GetDialogsRequest( + return _DialogsIter( + self, + limit, offset_date=offset_date, offset_id=offset_id, offset_peer=offset_peer, - limit=0, - hash=0 + ignore_migrated=ignore_migrated ) - while len(seen) < limit: - req.limit = min(limit - len(seen), 100) - r = await self(req) - - if _total: - _total[0] = getattr(r, 'count', len(r.dialogs)) - - entities = {utils.get_peer_id(x): x - for x in itertools.chain(r.users, r.chats)} - - messages = {} - for m in r.messages: - m._finish_init(self, entities, None) - messages[m.id] = m - - # Happens when there are pinned dialogs - if len(r.dialogs) > limit: - r.dialogs = r.dialogs[:limit] - - for d in r.dialogs: - if offset_date: - date = getattr(messages.get( - d.top_message, None), 'date', None) - - if not date or date.timestamp() > offset_date.timestamp(): - continue - - peer_id = utils.get_peer_id(d.peer) - if peer_id not in seen: - seen.add(peer_id) - cd = custom.Dialog(self, d, entities, messages) - if cd.dialog.pts: - self._channel_pts[cd.id] = cd.dialog.pts - - if not ignore_migrated or getattr( - cd.entity, 'migrated_to', None) is None: - await yield_(cd) - - if len(r.dialogs) < req.limit\ - or not isinstance(r, types.messages.DialogsSlice): - # Less than we requested means we reached the end, or - # we didn't get a DialogsSlice which means we got all. - break - - req.offset_date = r.messages[-1].date - req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)] - if req.offset_id == r.messages[-1].id: - # In some very rare cases this will get stuck in an infinite - # loop, where the offsets will get reused over and over. If - # the new offset is the same as the one before, break already. - break - - req.offset_id = r.messages[-1].id - req.exclude_pinned = True async def get_dialogs(self, *args, **kwargs): """ Same as `iter_dialogs`, but returns a `TotalList ` instead. """ - total = [0] - kwargs['_total'] = total - dialogs = helpers.TotalList() - async for x in self.iter_dialogs(*args, **kwargs): - dialogs.append(x) - dialogs.total = total[0] - return dialogs + return await self.iter_dialogs(*args, **kwargs).collect() - @async_generator - async def iter_drafts(self): + def iter_drafts(self): """ Iterator over all open draft messages. @@ -151,18 +162,14 @@ class DialogMethods(UserMethods): to change the message or `telethon.tl.custom.draft.Draft.delete` among other things. """ - r = await self(functions.messages.GetAllDraftsRequest()) - for update in r.updates: - await yield_(custom.Draft._from_update(self, update)) + # TODO Passing a limit here makes no sense + return _DraftsIter(self, None) async def get_drafts(self): """ Same as :meth:`iter_drafts`, but returns a list instead. """ - result = [] - async for x in self.iter_drafts(): - result.append(x) - return result + return await self.iter_drafts().collect() def conversation( self, entity,