Make use of the async_generator module

This commit is contained in:
Lonami Exo 2018-06-10 21:50:28 +02:00
parent 15ef302428
commit 8be6adeab4
3 changed files with 21 additions and 11 deletions

View File

@ -1,5 +1,7 @@
from collections import UserList from collections import UserList
from async_generator import async_generator, yield_
from .users import UserMethods from .users import UserMethods
from .. import utils from .. import utils
from ..tl import types, functions from ..tl import types, functions
@ -9,6 +11,7 @@ class ChatMethods(UserMethods):
# region Public methods # region Public methods
@async_generator
async def iter_participants( async def iter_participants(
self, entity, limit=None, search='', self, entity, limit=None, search='',
filter=None, aggressive=False, _total=None): filter=None, aggressive=False, _total=None):
@ -130,7 +133,7 @@ class ChatMethods(UserMethods):
seen.add(participant.user_id) seen.add(participant.user_id)
user = users[participant.user_id] user = users[participant.user_id]
user.participant = participant user.participant = participant
yield user await yield_(user)
if len(seen) >= limit: if len(seen) >= limit:
return return
@ -159,7 +162,7 @@ class ChatMethods(UserMethods):
else: else:
user = users[participant.user_id] user = users[participant.user_id]
user.participant = participant user.participant = participant
yield user await yield_(user)
else: else:
if _total: if _total:
_total[0] = 1 _total[0] = 1
@ -167,7 +170,7 @@ class ChatMethods(UserMethods):
user = await self.get_entity(entity) user = await self.get_entity(entity)
if filter_entity(user): if filter_entity(user):
user.participant = None user.participant = None
yield user await yield_(user)
async def get_participants(self, *args, **kwargs): async def get_participants(self, *args, **kwargs):
""" """

View File

@ -1,14 +1,17 @@
import itertools import itertools
from collections import UserList from collections import UserList
from async_generator import async_generator, yield_
from .users import UserMethods from .users import UserMethods
from ..tl import types, functions, custom
from .. import utils from .. import utils
from ..tl import types, functions, custom
class DialogMethods(UserMethods): class DialogMethods(UserMethods):
# region Public methods # region Public methods
@async_generator
async def iter_dialogs( async def iter_dialogs(
self, limit=None, offset_date=None, offset_id=0, self, limit=None, offset_date=None, offset_id=0,
offset_peer=types.InputPeerEmpty(), _total=None): offset_peer=types.InputPeerEmpty(), _total=None):
@ -80,7 +83,7 @@ class DialogMethods(UserMethods):
peer_id = utils.get_peer_id(d.peer) peer_id = utils.get_peer_id(d.peer)
if peer_id not in seen: if peer_id not in seen:
seen.add(peer_id) seen.add(peer_id)
yield custom.Dialog(self, d, entities, messages) await yield_(custom.Dialog(self, d, entities, messages))
if len(r.dialogs) < req.limit\ if len(r.dialogs) < req.limit\
or not isinstance(r, types.messages.DialogsSlice): or not isinstance(r, types.messages.DialogsSlice):
@ -115,7 +118,7 @@ class DialogMethods(UserMethods):
""" """
r = await self(functions.messages.GetAllDraftsRequest()) r = await self(functions.messages.GetAllDraftsRequest())
for update in r.updates: for update in r.updates:
yield custom.Draft._from_update(self, update) await yield_(custom.Draft._from_update(self, update))
async def get_drafts(self): async def get_drafts(self):
""" """

View File

@ -5,6 +5,8 @@ import time
import warnings import warnings
from collections import UserList from collections import UserList
from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods from .messageparse import MessageParseMethods
from .uploads import UploadMethods from .uploads import UploadMethods
from .. import utils from .. import utils
@ -19,6 +21,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Message retrieval # region Message retrieval
@async_generator
async def iter_messages( async def iter_messages(
self, entity, limit=None, offset_date=None, offset_id=0, self, entity, limit=None, offset_date=None, offset_id=0,
max_id=0, min_id=0, add_offset=0, search=None, filter=None, max_id=0, min_id=0, add_offset=0, search=None, filter=None,
@ -114,7 +117,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if not utils.is_list_like(ids): if not utils.is_list_like(ids):
ids = (ids,) ids = (ids,)
async for x in self._iter_ids(entity, ids, total=_total): async for x in self._iter_ids(entity, ids, total=_total):
yield x await yield_(x)
return return
# Telegram doesn't like min_id/max_id. If these IDs are low enough # Telegram doesn't like min_id/max_id. If these IDs are low enough
@ -202,7 +205,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# IDs are returned in descending order. # IDs are returned in descending order.
last_id = message.id last_id = message.id
yield custom.Message(self, message, entities, entity) await yield_(custom.Message(self, message, entities, entity))
have += 1 have += 1
if len(r.messages) < request.limit: if len(r.messages) < request.limit:
@ -620,6 +623,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Private methods # region Private methods
@async_generator
async def _iter_ids(self, entity, ids, total): async def _iter_ids(self, entity, ids, total):
""" """
Special case for `iter_messages` when it should only fetch some IDs. Special case for `iter_messages` when it should only fetch some IDs.
@ -634,7 +638,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if isinstance(r, types.messages.MessagesNotModified): if isinstance(r, types.messages.MessagesNotModified):
for _ in ids: for _ in ids:
yield None await yield_(None)
return return
entities = {utils.get_peer_id(x): x entities = {utils.get_peer_id(x): x
@ -644,8 +648,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# we asked them for, so we don't need to check it ourselves. # we asked them for, so we don't need to check it ourselves.
for message in r.messages: for message in r.messages:
if isinstance(message, types.MessageEmpty): if isinstance(message, types.MessageEmpty):
yield None await yield_(None)
else: else:
yield custom.Message(self, message, entities, entity) await yield_(custom.Message(self, message, entities, entity))
# endregion # endregion