Fix generators

This commit is contained in:
Lonami Exo 2018-06-28 15:32:18 +02:00
parent e1f8807d83
commit f41b41696a
4 changed files with 11 additions and 23 deletions

View File

@ -222,8 +222,7 @@ def main():
packages=find_packages(exclude=[ packages=find_packages(exclude=[
'telethon_*', 'run_tests.py', 'try_telethon.py' 'telethon_*', 'run_tests.py', 'try_telethon.py'
]), ]),
install_requires=['pyaes', 'rsa', install_requires=['pyaes', 'rsa'],
'async_generator'],
extras_require={ extras_require={
'cryptg': ['cryptg'] 'cryptg': ['cryptg']
} }

View File

@ -1,7 +1,5 @@
from collections import UserList from collections import UserList
from async_generator import async_generator, yield_
from .users import UserMethods from .users import UserMethods
from .. import utils from .. import utils
from ..tl import types, functions from ..tl import types, functions
@ -11,7 +9,6 @@ class ChatMethods(UserMethods):
# region Public methods # region Public methods
@async_generator
def iter_participants( def iter_participants(
self, entity, limit=None, *, search='', self, entity, limit=None, *, search='',
filter=None, aggressive=False, _total=None): filter=None, aggressive=False, _total=None):
@ -133,7 +130,7 @@ class ChatMethods(UserMethods):
seen.add(participant.user_id) seen.add(participant.user_id)
user = users[participant.user_id] user = users[participant.user_id]
user.participant = participant user.participant = participant
yield_(user) yield (user)
if len(seen) >= limit: if len(seen) >= limit:
return return
@ -161,7 +158,7 @@ class ChatMethods(UserMethods):
else: else:
user = users[participant.user_id] user = users[participant.user_id]
user.participant = participant user.participant = participant
yield_(user) yield (user)
else: else:
if _total: if _total:
_total[0] = 1 _total[0] = 1
@ -169,7 +166,7 @@ class ChatMethods(UserMethods):
user = self.get_entity(entity) user = self.get_entity(entity)
if filter_entity(user): if filter_entity(user):
user.participant = None user.participant = None
yield_(user) yield (user)
def get_participants(self, *args, **kwargs): def get_participants(self, *args, **kwargs):
""" """

View File

@ -1,8 +1,6 @@
import itertools import itertools
from collections import UserList from collections import UserList
from async_generator import async_generator, yield_
from .users import UserMethods from .users import UserMethods
from .. import utils from .. import utils
from ..tl import types, functions, custom from ..tl import types, functions, custom
@ -12,7 +10,6 @@ class DialogMethods(UserMethods):
# region Public methods # region Public methods
@async_generator
def iter_dialogs( def iter_dialogs(
self, limit=None, *, offset_date=None, offset_id=0, self, limit=None, *, offset_date=None, offset_id=0,
offset_peer=types.InputPeerEmpty(), ignore_migrated=False, offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
@ -97,7 +94,7 @@ class DialogMethods(UserMethods):
if not ignore_migrated or getattr( if not ignore_migrated or getattr(
cd.entity, 'migrated_to', None) is None: cd.entity, 'migrated_to', None) is None:
yield_(cd) yield (cd)
if len(r.dialogs) < req.limit\ if len(r.dialogs) < req.limit\
or not isinstance(r, types.messages.DialogsSlice): or not isinstance(r, types.messages.DialogsSlice):
@ -129,7 +126,6 @@ class DialogMethods(UserMethods):
dialogs.total = total[0] dialogs.total = total[0]
return dialogs return dialogs
@async_generator
def iter_drafts(self): def iter_drafts(self):
""" """
Iterator over all open draft messages. Iterator over all open draft messages.
@ -141,7 +137,7 @@ class DialogMethods(UserMethods):
""" """
r = self(functions.messages.GetAllDraftsRequest()) r = self(functions.messages.GetAllDraftsRequest())
for update in r.updates: for update in r.updates:
yield_(custom.Draft._from_update(self, update)) yield (custom.Draft._from_update(self, update))
def get_drafts(self): def get_drafts(self):
""" """

View File

@ -4,8 +4,6 @@ import logging
import time import time
from collections import UserList from collections import UserList
from async_generator import async_generator, yield_
from .messageparse import MessageParseMethods from .messageparse import MessageParseMethods
from .uploads import UploadMethods from .uploads import UploadMethods
from .. import utils from .. import utils
@ -20,7 +18,6 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Message retrieval # region Message retrieval
@async_generator
def iter_messages( def iter_messages(
self, entity, limit=None, *, offset_date=None, offset_id=0, self, entity, limit=None, *, offset_date=None, offset_id=0,
max_id=0, min_id=0, add_offset=0, search=None, filter=None, max_id=0, min_id=0, add_offset=0, search=None, filter=None,
@ -116,7 +113,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if not utils.is_list_like(ids): if not utils.is_list_like(ids):
ids = (ids,) ids = (ids,)
for x in self._iter_ids(entity, ids, total=_total): for x in self._iter_ids(entity, ids, total=_total):
yield_(x) yield (x)
return return
# Telegram doesn't like min_id/max_id. If these IDs are low enough # Telegram doesn't like min_id/max_id. If these IDs are low enough
@ -213,7 +210,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# IDs are returned in descending order. # IDs are returned in descending order.
last_id = message.id last_id = message.id
yield_(custom.Message(self, message, entities, entity)) yield (custom.Message(self, message, entities, entity))
have += 1 have += 1
if len(r.messages) < request.limit: if len(r.messages) < request.limit:
@ -644,7 +641,6 @@ class MessageMethods(UploadMethods, MessageParseMethods):
# region Private methods # region Private methods
@async_generator
def _iter_ids(self, entity, ids, total): def _iter_ids(self, entity, ids, total):
""" """
Special case for `iter_messages` when it should only fetch some IDs. Special case for `iter_messages` when it should only fetch some IDs.
@ -662,7 +658,7 @@ class MessageMethods(UploadMethods, MessageParseMethods):
if isinstance(r, types.messages.MessagesNotModified): if isinstance(r, types.messages.MessagesNotModified):
for _ in ids: for _ in ids:
yield_(None) yield (None)
return return
entities = {utils.get_peer_id(x): x entities = {utils.get_peer_id(x): x
@ -677,8 +673,8 @@ class MessageMethods(UploadMethods, MessageParseMethods):
for message in r.messages: for message in r.messages:
if isinstance(message, types.MessageEmpty) or ( if isinstance(message, types.MessageEmpty) or (
from_id and utils.get_peer_id(message.to_id) != from_id): from_id and utils.get_peer_id(message.to_id) != from_id):
yield_(None) yield (None)
else: else:
yield_(custom.Message(self, message, entities, entity)) yield (custom.Message(self, message, entities, entity))
# endregion # endregion