Make use of the async_generator module

This commit is contained in:
Lonami Exo 2018-06-10 21:50:28 +02:00
parent 15ef302428
commit 8be6adeab4
3 changed files with 21 additions and 11 deletions

View File

@ -1,5 +1,7 @@
from collections import UserList
from async_generator import async_generator, yield_
from .users import UserMethods
from .. import utils
from ..tl import types, functions
@ -9,6 +11,7 @@ class ChatMethods(UserMethods):
# region Public methods
@async_generator
async def iter_participants(
self, entity, limit=None, search='',
filter=None, aggressive=False, _total=None):
@ -130,7 +133,7 @@ class ChatMethods(UserMethods):
seen.add(participant.user_id)
user = users[participant.user_id]
user.participant = participant
yield user
await yield_(user)
if len(seen) >= limit:
return
@ -159,7 +162,7 @@ class ChatMethods(UserMethods):
else:
user = users[participant.user_id]
user.participant = participant
yield user
await yield_(user)
else:
if _total:
_total[0] = 1
@ -167,7 +170,7 @@ class ChatMethods(UserMethods):
user = await self.get_entity(entity)
if filter_entity(user):
user.participant = None
yield user
await yield_(user)
async def get_participants(self, *args, **kwargs):
"""

View File

@ -1,14 +1,17 @@
import itertools
from collections import UserList
from async_generator import async_generator, yield_
from .users import UserMethods
from ..tl import types, functions, custom
from .. import utils
from ..tl import types, functions, custom
class DialogMethods(UserMethods):
# region Public methods
@async_generator
async def iter_dialogs(
self, limit=None, offset_date=None, offset_id=0,
offset_peer=types.InputPeerEmpty(), _total=None):
@ -80,7 +83,7 @@ class DialogMethods(UserMethods):
peer_id = utils.get_peer_id(d.peer)
if peer_id not in seen:
seen.add(peer_id)
yield custom.Dialog(self, d, entities, messages)
await yield_(custom.Dialog(self, d, entities, messages))
if len(r.dialogs) < req.limit\
or not isinstance(r, types.messages.DialogsSlice):
@ -115,7 +118,7 @@ class DialogMethods(UserMethods):
"""
r = await self(functions.messages.GetAllDraftsRequest())
for update in r.updates:
yield custom.Draft._from_update(self, update)
await yield_(custom.Draft._from_update(self, update))
async def get_drafts(self):
"""

View File

@ -5,6 +5,8 @@ import time
import warnings
from collections import UserList
from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods
from .uploads import UploadMethods
from .. import utils
@ -19,6 +21,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Message retrieval
@async_generator
async def iter_messages(
self, entity, limit=None, offset_date=None, offset_id=0,
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
@ -114,7 +117,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if not utils.is_list_like(ids):
ids = (ids,)
async for x in self._iter_ids(entity, ids, total=_total):
yield x
await yield_(x)
return
# Telegram doesn't like min_id/max_id. If these IDs are low enough
@ -202,7 +205,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# IDs are returned in descending order.
last_id = message.id
yield custom.Message(self, message, entities, entity)
await yield_(custom.Message(self, message, entities, entity))
have += 1
if len(r.messages) < request.limit:
@ -620,6 +623,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Private methods
@async_generator
async def _iter_ids(self, entity, ids, total):
"""
Special case for `iter_messages` when it should only fetch some IDs.
@ -634,7 +638,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if isinstance(r, types.messages.MessagesNotModified):
for _ in ids:
yield None
await yield_(None)
return
entities = {utils.get_peer_id(x): x
@ -644,8 +648,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# we asked them for, so we don't need to check it ourselves.
for message in r.messages:
if isinstance(message, types.MessageEmpty):
yield None
await yield_(None)
else:
yield custom.Message(self, message, entities, entity)
await yield_(custom.Message(self, message, entities, entity))
# endregion