Use RequestIter in the dialog methods

This commit is contained in:
Lonami Exo 2019-02-27 10:37:40 +01:00
parent 49d8a3fb33
commit 968da5f72d

View File

@ -1,18 +1,105 @@
import itertools import itertools
from async_generator import async_generator, yield_
from .users import UserMethods from .users import UserMethods
from .. import utils, helpers from .. import utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom from ..tl import types, functions, custom
class _DialogsIter(RequestIter):
async def _init(
self, offset_date, offset_id, offset_peer, ignore_migrated
):
self.request = functions.messages.GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=1,
hash=0
)
if self.limit == 0:
# Special case, get a single dialog and determine count
dialogs = await self.client(self.request)
self.total = getattr(dialogs, 'count', len(dialogs.dialogs))
raise StopAsyncIteration
self.seen = set()
self.offset_date = offset_date
self.ignore_migrated = ignore_migrated
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
self.total = getattr(r, 'count', len(r.dialogs))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = {}
for m in r.messages:
m._finish_init(self, entities, None)
messages[m.id] = m
for d in r.dialogs:
# We check the offset date here because Telegram may ignore it
if self.offset_date:
date = getattr(messages.get(
d.top_message, None), 'date', None)
if not date or date.timestamp() > self.offset_date.timestamp():
continue
peer_id = utils.get_peer_id(d.peer)
if peer_id not in self.seen:
self.seen.add(peer_id)
cd = custom.Dialog(self, d, entities, messages)
if cd.dialog.pts:
self.client._channel_pts[cd.id] = cd.dialog.pts
if not self.ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
result.append(cd)
if len(r.dialogs) < self.request.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
self.left = len(result)
self.request.offset_date = r.messages[-1].date
self.request.offset_peer =\
entities[utils.get_peer_id(r.dialogs[-1].peer)]
if self.request.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
self.left = len(result)
self.request.offset_id = r.messages[-1].id
self.request.exclude_pinned = True
return result
class _DraftsIter(RequestIter):
async def _init(self, **kwargs):
r = await self.client(functions.messages.GetAllDraftsRequest())
self.buffer = [custom.Draft._from_update(self.client, u)
for u in r.updates]
async def _load_next_chunk(self):
return []
class DialogMethods(UserMethods): class DialogMethods(UserMethods):
# region Public methods # region Public methods
@async_generator def iter_dialogs(
async def iter_dialogs(
self, limit=None, *, offset_date=None, offset_id=0, self, limit=None, *, offset_date=None, offset_id=0,
offset_peer=types.InputPeerEmpty(), ignore_migrated=False, offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
_total=None): _total=None):
@ -50,99 +137,23 @@ class DialogMethods(UserMethods):
Yields: Yields:
Instances of `telethon.tl.custom.dialog.Dialog`. Instances of `telethon.tl.custom.dialog.Dialog`.
""" """
limit = float('inf') if limit is None else int(limit) return _DialogsIter(
if limit == 0: self,
if not _total: limit,
return
# Special case, get a single dialog and determine count
dialogs = await self(functions.messages.GetDialogsRequest(
offset_date=offset_date, offset_date=offset_date,
offset_id=offset_id, offset_id=offset_id,
offset_peer=offset_peer, offset_peer=offset_peer,
limit=1, ignore_migrated=ignore_migrated
hash=0
))
_total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
return
seen = set()
req = functions.messages.GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=0,
hash=0
) )
while len(seen) < limit:
req.limit = min(limit - len(seen), 100)
r = await self(req)
if _total:
_total[0] = getattr(r, 'count', len(r.dialogs))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = {}
for m in r.messages:
m._finish_init(self, entities, None)
messages[m.id] = m
# Happens when there are pinned dialogs
if len(r.dialogs) > limit:
r.dialogs = r.dialogs[:limit]
for d in r.dialogs:
if offset_date:
date = getattr(messages.get(
d.top_message, None), 'date', None)
if not date or date.timestamp() > offset_date.timestamp():
continue
peer_id = utils.get_peer_id(d.peer)
if peer_id not in seen:
seen.add(peer_id)
cd = custom.Dialog(self, d, entities, messages)
if cd.dialog.pts:
self._channel_pts[cd.id] = cd.dialog.pts
if not ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
await yield_(cd)
if len(r.dialogs) < req.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
break
req.offset_date = r.messages[-1].date
req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
if req.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
break
req.offset_id = r.messages[-1].id
req.exclude_pinned = True
async def get_dialogs(self, *args, **kwargs): async def get_dialogs(self, *args, **kwargs):
""" """
Same as `iter_dialogs`, but returns a Same as `iter_dialogs`, but returns a
`TotalList <telethon.helpers.TotalList>` instead. `TotalList <telethon.helpers.TotalList>` instead.
""" """
total = [0] return await self.iter_dialogs(*args, **kwargs).collect()
kwargs['_total'] = total
dialogs = helpers.TotalList()
async for x in self.iter_dialogs(*args, **kwargs):
dialogs.append(x)
dialogs.total = total[0]
return dialogs
@async_generator def iter_drafts(self):
async def iter_drafts(self):
""" """
Iterator over all open draft messages. Iterator over all open draft messages.
@ -151,18 +162,14 @@ class DialogMethods(UserMethods):
to change the message or `telethon.tl.custom.draft.Draft.delete` to change the message or `telethon.tl.custom.draft.Draft.delete`
among other things. among other things.
""" """
r = await self(functions.messages.GetAllDraftsRequest()) # TODO Passing a limit here makes no sense
for update in r.updates: return _DraftsIter(self, None)
await yield_(custom.Draft._from_update(self, update))
async def get_drafts(self): async def get_drafts(self):
""" """
Same as :meth:`iter_drafts`, but returns a list instead. Same as :meth:`iter_drafts`, but returns a list instead.
""" """
result = [] return await self.iter_drafts().collect()
async for x in self.iter_drafts():
result.append(x)
return result
def conversation( def conversation(
self, entity, self, entity,