Use RequestIter in the dialog methods

This commit is contained in:
Lonami Exo 2019-02-27 10:37:40 +01:00
parent 49d8a3fb33
commit 968da5f72d

View File

@ -1,18 +1,105 @@
import itertools
from async_generator import async_generator, yield_
from .users import UserMethods
from .. import utils, helpers
from .. import utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom
class _DialogsIter(RequestIter):
async def _init(
self, offset_date, offset_id, offset_peer, ignore_migrated
):
self.request = functions.messages.GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=1,
hash=0
)
if self.limit == 0:
# Special case, get a single dialog and determine count
dialogs = await self.client(self.request)
self.total = getattr(dialogs, 'count', len(dialogs.dialogs))
raise StopAsyncIteration
self.seen = set()
self.offset_date = offset_date
self.ignore_migrated = ignore_migrated
async def _load_next_chunk(self):
result = []
self.request.limit = min(self.left, 100)
r = await self.client(self.request)
self.total = getattr(r, 'count', len(r.dialogs))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = {}
for m in r.messages:
m._finish_init(self, entities, None)
messages[m.id] = m
for d in r.dialogs:
# We check the offset date here because Telegram may ignore it
if self.offset_date:
date = getattr(messages.get(
d.top_message, None), 'date', None)
if not date or date.timestamp() > self.offset_date.timestamp():
continue
peer_id = utils.get_peer_id(d.peer)
if peer_id not in self.seen:
self.seen.add(peer_id)
cd = custom.Dialog(self, d, entities, messages)
if cd.dialog.pts:
self.client._channel_pts[cd.id] = cd.dialog.pts
if not self.ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
result.append(cd)
if len(r.dialogs) < self.request.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
self.left = len(result)
self.request.offset_date = r.messages[-1].date
self.request.offset_peer =\
entities[utils.get_peer_id(r.dialogs[-1].peer)]
if self.request.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
self.left = len(result)
self.request.offset_id = r.messages[-1].id
self.request.exclude_pinned = True
return result
class _DraftsIter(RequestIter):
async def _init(self, **kwargs):
r = await self.client(functions.messages.GetAllDraftsRequest())
self.buffer = [custom.Draft._from_update(self.client, u)
for u in r.updates]
async def _load_next_chunk(self):
return []
class DialogMethods(UserMethods):
# region Public methods
@async_generator
async def iter_dialogs(
def iter_dialogs(
self, limit=None, *, offset_date=None, offset_id=0,
offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
_total=None):
@ -50,99 +137,23 @@ class DialogMethods(UserMethods):
Yields:
Instances of `telethon.tl.custom.dialog.Dialog`.
"""
limit = float('inf') if limit is None else int(limit)
if limit == 0:
if not _total:
return
# Special case, get a single dialog and determine count
dialogs = await self(functions.messages.GetDialogsRequest(
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=1,
hash=0
))
_total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
return
seen = set()
req = functions.messages.GetDialogsRequest(
return _DialogsIter(
self,
limit,
offset_date=offset_date,
offset_id=offset_id,
offset_peer=offset_peer,
limit=0,
hash=0
ignore_migrated=ignore_migrated
)
while len(seen) < limit:
req.limit = min(limit - len(seen), 100)
r = await self(req)
if _total:
_total[0] = getattr(r, 'count', len(r.dialogs))
entities = {utils.get_peer_id(x): x
for x in itertools.chain(r.users, r.chats)}
messages = {}
for m in r.messages:
m._finish_init(self, entities, None)
messages[m.id] = m
# Happens when there are pinned dialogs
if len(r.dialogs) > limit:
r.dialogs = r.dialogs[:limit]
for d in r.dialogs:
if offset_date:
date = getattr(messages.get(
d.top_message, None), 'date', None)
if not date or date.timestamp() > offset_date.timestamp():
continue
peer_id = utils.get_peer_id(d.peer)
if peer_id not in seen:
seen.add(peer_id)
cd = custom.Dialog(self, d, entities, messages)
if cd.dialog.pts:
self._channel_pts[cd.id] = cd.dialog.pts
if not ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None:
await yield_(cd)
if len(r.dialogs) < req.limit\
or not isinstance(r, types.messages.DialogsSlice):
# Less than we requested means we reached the end, or
# we didn't get a DialogsSlice which means we got all.
break
req.offset_date = r.messages[-1].date
req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
if req.offset_id == r.messages[-1].id:
# In some very rare cases this will get stuck in an infinite
# loop, where the offsets will get reused over and over. If
# the new offset is the same as the one before, break already.
break
req.offset_id = r.messages[-1].id
req.exclude_pinned = True
async def get_dialogs(self, *args, **kwargs):
"""
Same as `iter_dialogs`, but returns a
`TotalList <telethon.helpers.TotalList>` instead.
"""
total = [0]
kwargs['_total'] = total
dialogs = helpers.TotalList()
async for x in self.iter_dialogs(*args, **kwargs):
dialogs.append(x)
dialogs.total = total[0]
return dialogs
return await self.iter_dialogs(*args, **kwargs).collect()
@async_generator
async def iter_drafts(self):
def iter_drafts(self):
"""
Iterator over all open draft messages.
@ -151,18 +162,14 @@ class DialogMethods(UserMethods):
to change the message or `telethon.tl.custom.draft.Draft.delete`
among other things.
"""
r = await self(functions.messages.GetAllDraftsRequest())
for update in r.updates:
await yield_(custom.Draft._from_update(self, update))
# TODO Passing a limit here makes no sense
return _DraftsIter(self, None)
async def get_drafts(self):
"""
Same as :meth:`iter_drafts`, but returns a list instead.
"""
result = []
async for x in self.iter_drafts():
result.append(x)
return result
return await self.iter_drafts().collect()
def conversation(
self, entity,