mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-10 16:40:57 +03:00
Merge pull request #1115 from LonamiWebs/requestiter
Overhaul asynchronous generators
This commit is contained in:
commit
c99157ade2
|
@ -1,3 +1,2 @@
|
||||||
pyaes
|
pyaes
|
||||||
rsa
|
rsa
|
||||||
async_generator
|
|
||||||
|
|
3
setup.py
3
setup.py
|
@ -214,8 +214,7 @@ def main():
|
||||||
packages=find_packages(exclude=[
|
packages=find_packages(exclude=[
|
||||||
'telethon_*', 'run_tests.py', 'try_telethon.py'
|
'telethon_*', 'run_tests.py', 'try_telethon.py'
|
||||||
]),
|
]),
|
||||||
install_requires=['pyaes', 'rsa',
|
install_requires=['pyaes', 'rsa'],
|
||||||
'async_generator'],
|
|
||||||
extras_require={
|
extras_require={
|
||||||
'cryptg': ['cryptg']
|
'cryptg': ['cryptg']
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,19 +1,192 @@
|
||||||
import itertools
|
import itertools
|
||||||
import sys
|
|
||||||
|
|
||||||
from async_generator import async_generator, yield_
|
|
||||||
|
|
||||||
from .users import UserMethods
|
from .users import UserMethods
|
||||||
from .. import utils, helpers
|
from .. import utils
|
||||||
|
from ..requestiter import RequestIter
|
||||||
from ..tl import types, functions, custom
|
from ..tl import types, functions, custom
|
||||||
|
|
||||||
|
|
||||||
|
class _ParticipantsIter(RequestIter):
|
||||||
|
async def _init(self, entity, filter, search, aggressive):
|
||||||
|
if isinstance(filter, type):
|
||||||
|
if filter in (types.ChannelParticipantsBanned,
|
||||||
|
types.ChannelParticipantsKicked,
|
||||||
|
types.ChannelParticipantsSearch,
|
||||||
|
types.ChannelParticipantsContacts):
|
||||||
|
# These require a `q` parameter (support types for convenience)
|
||||||
|
filter = filter('')
|
||||||
|
else:
|
||||||
|
filter = filter()
|
||||||
|
|
||||||
|
entity = await self.client.get_input_entity(entity)
|
||||||
|
if search and (filter
|
||||||
|
or not isinstance(entity, types.InputPeerChannel)):
|
||||||
|
# We need to 'search' ourselves unless we have a PeerChannel
|
||||||
|
search = search.lower()
|
||||||
|
|
||||||
|
self.filter_entity = lambda ent: (
|
||||||
|
search in utils.get_display_name(ent).lower() or
|
||||||
|
search in (getattr(ent, 'username', '') or None).lower()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.filter_entity = lambda ent: True
|
||||||
|
|
||||||
|
if isinstance(entity, types.InputPeerChannel):
|
||||||
|
self.total = (await self.client(
|
||||||
|
functions.channels.GetFullChannelRequest(entity)
|
||||||
|
)).full_chat.participants_count
|
||||||
|
|
||||||
|
if self.limit == 0:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
self.seen = set()
|
||||||
|
if aggressive and not filter:
|
||||||
|
self.requests = [functions.channels.GetParticipantsRequest(
|
||||||
|
channel=entity,
|
||||||
|
filter=types.ChannelParticipantsSearch(x),
|
||||||
|
offset=0,
|
||||||
|
limit=200,
|
||||||
|
hash=0
|
||||||
|
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
|
||||||
|
else:
|
||||||
|
self.requests = [functions.channels.GetParticipantsRequest(
|
||||||
|
channel=entity,
|
||||||
|
filter=filter or types.ChannelParticipantsSearch(search),
|
||||||
|
offset=0,
|
||||||
|
limit=200,
|
||||||
|
hash=0
|
||||||
|
)]
|
||||||
|
|
||||||
|
elif isinstance(entity, types.InputPeerChat):
|
||||||
|
full = await self.client(
|
||||||
|
functions.messages.GetFullChatRequest(entity.chat_id))
|
||||||
|
if not isinstance(
|
||||||
|
full.full_chat.participants, types.ChatParticipants):
|
||||||
|
# ChatParticipantsForbidden won't have ``.participants``
|
||||||
|
self.total = 0
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
self.total = len(full.full_chat.participants.participants)
|
||||||
|
|
||||||
|
users = {user.id: user for user in full.users}
|
||||||
|
for participant in full.full_chat.participants.participants:
|
||||||
|
user = users[participant.user_id]
|
||||||
|
if not self.filter_entity(user):
|
||||||
|
continue
|
||||||
|
|
||||||
|
user = users[participant.user_id]
|
||||||
|
user.participant = participant
|
||||||
|
self.buffer.append(user)
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self.total = 1
|
||||||
|
if self.limit != 0:
|
||||||
|
user = await self.client.get_entity(entity)
|
||||||
|
if self.filter_entity(user):
|
||||||
|
user.participant = None
|
||||||
|
self.buffer.append(user)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
if not self.requests:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Only care about the limit for the first request
|
||||||
|
# (small amount of people, won't be aggressive).
|
||||||
|
#
|
||||||
|
# Most people won't care about getting exactly 12,345
|
||||||
|
# members so it doesn't really matter not to be 100%
|
||||||
|
# precise with being out of the offset/limit here.
|
||||||
|
self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
|
||||||
|
if self.requests[0].offset > self.limit:
|
||||||
|
return True
|
||||||
|
|
||||||
|
results = await self.client(self.requests)
|
||||||
|
for i in reversed(range(len(self.requests))):
|
||||||
|
participants = results[i]
|
||||||
|
if not participants.users:
|
||||||
|
self.requests.pop(i)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.requests[i].offset += len(participants.participants)
|
||||||
|
users = {user.id: user for user in participants.users}
|
||||||
|
for participant in participants.participants:
|
||||||
|
user = users[participant.user_id]
|
||||||
|
if not self.filter_entity(user) or user.id in self.seen:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.seen.add(participant.user_id)
|
||||||
|
user = users[participant.user_id]
|
||||||
|
user.participant = participant
|
||||||
|
self.buffer.append(user)
|
||||||
|
|
||||||
|
|
||||||
|
class _AdminLogIter(RequestIter):
|
||||||
|
async def _init(
|
||||||
|
self, entity, admins, search, min_id, max_id,
|
||||||
|
join, leave, invite, restrict, unrestrict, ban, unban,
|
||||||
|
promote, demote, info, settings, pinned, edit, delete
|
||||||
|
):
|
||||||
|
if any((join, leave, invite, restrict, unrestrict, ban, unban,
|
||||||
|
promote, demote, info, settings, pinned, edit, delete)):
|
||||||
|
events_filter = types.ChannelAdminLogEventsFilter(
|
||||||
|
join=join, leave=leave, invite=invite, ban=restrict,
|
||||||
|
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
|
||||||
|
demote=demote, info=info, settings=settings, pinned=pinned,
|
||||||
|
edit=edit, delete=delete
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
events_filter = None
|
||||||
|
|
||||||
|
self.entity = await self.client.get_input_entity(entity)
|
||||||
|
|
||||||
|
admin_list = []
|
||||||
|
if admins:
|
||||||
|
if not utils.is_list_like(admins):
|
||||||
|
admins = (admins,)
|
||||||
|
|
||||||
|
for admin in admins:
|
||||||
|
admin_list.append(await self.client.get_input_entity(admin))
|
||||||
|
|
||||||
|
self.request = functions.channels.GetAdminLogRequest(
|
||||||
|
self.entity, q=search or '', min_id=min_id, max_id=max_id,
|
||||||
|
limit=0, events_filter=events_filter, admins=admin_list or None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
self.request.limit = min(self.left, 100)
|
||||||
|
r = await self.client(self.request)
|
||||||
|
entities = {utils.get_peer_id(x): x
|
||||||
|
for x in itertools.chain(r.users, r.chats)}
|
||||||
|
|
||||||
|
self.request.max_id = min((e.id for e in r.events), default=0)
|
||||||
|
for ev in r.events:
|
||||||
|
if isinstance(ev.action,
|
||||||
|
types.ChannelAdminLogEventActionEditMessage):
|
||||||
|
ev.action.prev_message._finish_init(
|
||||||
|
self.client, entities, self.entity)
|
||||||
|
|
||||||
|
ev.action.new_message._finish_init(
|
||||||
|
self.client, entities, self.entity)
|
||||||
|
|
||||||
|
elif isinstance(ev.action,
|
||||||
|
types.ChannelAdminLogEventActionDeleteMessage):
|
||||||
|
ev.action.message._finish_init(
|
||||||
|
self.client, entities, self.entity)
|
||||||
|
|
||||||
|
self.buffer.append(custom.AdminLogEvent(ev, entities))
|
||||||
|
|
||||||
|
if len(r.events) < self.request.limit:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ChatMethods(UserMethods):
|
class ChatMethods(UserMethods):
|
||||||
|
|
||||||
# region Public methods
|
# region Public methods
|
||||||
|
|
||||||
@async_generator
|
def iter_participants(
|
||||||
async def iter_participants(
|
|
||||||
self, entity, limit=None, *, search='',
|
self, entity, limit=None, *, search='',
|
||||||
filter=None, aggressive=False, _total=None):
|
filter=None, aggressive=False, _total=None):
|
||||||
"""
|
"""
|
||||||
|
@ -62,138 +235,23 @@ class ChatMethods(UserMethods):
|
||||||
matched :tl:`ChannelParticipant` type for channels/megagroups
|
matched :tl:`ChannelParticipant` type for channels/megagroups
|
||||||
or :tl:`ChatParticipants` for normal chats.
|
or :tl:`ChatParticipants` for normal chats.
|
||||||
"""
|
"""
|
||||||
if isinstance(filter, type):
|
return _ParticipantsIter(
|
||||||
if filter in (types.ChannelParticipantsBanned,
|
self,
|
||||||
types.ChannelParticipantsKicked,
|
limit,
|
||||||
types.ChannelParticipantsSearch,
|
entity=entity,
|
||||||
types.ChannelParticipantsContacts):
|
filter=filter,
|
||||||
# These require a `q` parameter (support types for convenience)
|
search=search,
|
||||||
filter = filter('')
|
aggressive=aggressive
|
||||||
else:
|
)
|
||||||
filter = filter()
|
|
||||||
|
|
||||||
entity = await self.get_input_entity(entity)
|
|
||||||
if search and (filter
|
|
||||||
or not isinstance(entity, types.InputPeerChannel)):
|
|
||||||
# We need to 'search' ourselves unless we have a PeerChannel
|
|
||||||
search = search.lower()
|
|
||||||
|
|
||||||
def filter_entity(ent):
|
|
||||||
return search in utils.get_display_name(ent).lower() or\
|
|
||||||
search in (getattr(ent, 'username', '') or None).lower()
|
|
||||||
else:
|
|
||||||
def filter_entity(ent):
|
|
||||||
return True
|
|
||||||
|
|
||||||
limit = float('inf') if limit is None else int(limit)
|
|
||||||
if isinstance(entity, types.InputPeerChannel):
|
|
||||||
if _total:
|
|
||||||
_total[0] = (await self(
|
|
||||||
functions.channels.GetFullChannelRequest(entity)
|
|
||||||
)).full_chat.participants_count
|
|
||||||
|
|
||||||
if limit == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
seen = set()
|
|
||||||
if aggressive and not filter:
|
|
||||||
requests = [functions.channels.GetParticipantsRequest(
|
|
||||||
channel=entity,
|
|
||||||
filter=types.ChannelParticipantsSearch(x),
|
|
||||||
offset=0,
|
|
||||||
limit=200,
|
|
||||||
hash=0
|
|
||||||
) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
|
|
||||||
else:
|
|
||||||
requests = [functions.channels.GetParticipantsRequest(
|
|
||||||
channel=entity,
|
|
||||||
filter=filter or types.ChannelParticipantsSearch(search),
|
|
||||||
offset=0,
|
|
||||||
limit=200,
|
|
||||||
hash=0
|
|
||||||
)]
|
|
||||||
|
|
||||||
while requests:
|
|
||||||
# Only care about the limit for the first request
|
|
||||||
# (small amount of people, won't be aggressive).
|
|
||||||
#
|
|
||||||
# Most people won't care about getting exactly 12,345
|
|
||||||
# members so it doesn't really matter not to be 100%
|
|
||||||
# precise with being out of the offset/limit here.
|
|
||||||
requests[0].limit = min(limit - requests[0].offset, 200)
|
|
||||||
if requests[0].offset > limit:
|
|
||||||
break
|
|
||||||
|
|
||||||
results = await self(requests)
|
|
||||||
for i in reversed(range(len(requests))):
|
|
||||||
participants = results[i]
|
|
||||||
if not participants.users:
|
|
||||||
requests.pop(i)
|
|
||||||
else:
|
|
||||||
requests[i].offset += len(participants.participants)
|
|
||||||
users = {user.id: user for user in participants.users}
|
|
||||||
for participant in participants.participants:
|
|
||||||
user = users[participant.user_id]
|
|
||||||
if not filter_entity(user) or user.id in seen:
|
|
||||||
continue
|
|
||||||
|
|
||||||
seen.add(participant.user_id)
|
|
||||||
user = users[participant.user_id]
|
|
||||||
user.participant = participant
|
|
||||||
await yield_(user)
|
|
||||||
if len(seen) >= limit:
|
|
||||||
return
|
|
||||||
|
|
||||||
elif isinstance(entity, types.InputPeerChat):
|
|
||||||
full = await self(
|
|
||||||
functions.messages.GetFullChatRequest(entity.chat_id))
|
|
||||||
if not isinstance(
|
|
||||||
full.full_chat.participants, types.ChatParticipants):
|
|
||||||
# ChatParticipantsForbidden won't have ``.participants``
|
|
||||||
if _total:
|
|
||||||
_total[0] = 0
|
|
||||||
return
|
|
||||||
|
|
||||||
if _total:
|
|
||||||
_total[0] = len(full.full_chat.participants.participants)
|
|
||||||
|
|
||||||
have = 0
|
|
||||||
users = {user.id: user for user in full.users}
|
|
||||||
for participant in full.full_chat.participants.participants:
|
|
||||||
user = users[participant.user_id]
|
|
||||||
if not filter_entity(user):
|
|
||||||
continue
|
|
||||||
have += 1
|
|
||||||
if have > limit:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
user = users[participant.user_id]
|
|
||||||
user.participant = participant
|
|
||||||
await yield_(user)
|
|
||||||
else:
|
|
||||||
if _total:
|
|
||||||
_total[0] = 1
|
|
||||||
if limit != 0:
|
|
||||||
user = await self.get_entity(entity)
|
|
||||||
if filter_entity(user):
|
|
||||||
user.participant = None
|
|
||||||
await yield_(user)
|
|
||||||
|
|
||||||
async def get_participants(self, *args, **kwargs):
|
async def get_participants(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Same as `iter_participants`, but returns a
|
Same as `iter_participants`, but returns a
|
||||||
`TotalList <telethon.helpers.TotalList>` instead.
|
`TotalList <telethon.helpers.TotalList>` instead.
|
||||||
"""
|
"""
|
||||||
total = [0]
|
return await self.iter_participants(*args, **kwargs).collect()
|
||||||
kwargs['_total'] = total
|
|
||||||
participants = helpers.TotalList()
|
|
||||||
async for x in self.iter_participants(*args, **kwargs):
|
|
||||||
participants.append(x)
|
|
||||||
participants.total = total[0]
|
|
||||||
return participants
|
|
||||||
|
|
||||||
@async_generator
|
def iter_admin_log(
|
||||||
async def iter_admin_log(
|
|
||||||
self, entity, limit=None, *, max_id=0, min_id=0, search=None,
|
self, entity, limit=None, *, max_id=0, min_id=0, search=None,
|
||||||
admins=None, join=None, leave=None, invite=None, restrict=None,
|
admins=None, join=None, leave=None, invite=None, restrict=None,
|
||||||
unrestrict=None, ban=None, unban=None, promote=None, demote=None,
|
unrestrict=None, ban=None, unban=None, promote=None, demote=None,
|
||||||
|
@ -285,66 +343,34 @@ class ChatMethods(UserMethods):
|
||||||
Yields:
|
Yields:
|
||||||
Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
|
Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
|
||||||
"""
|
"""
|
||||||
if limit is None:
|
return _AdminLogIter(
|
||||||
limit = sys.maxsize
|
self,
|
||||||
elif limit <= 0:
|
limit,
|
||||||
return
|
entity=entity,
|
||||||
|
admins=admins,
|
||||||
if any((join, leave, invite, restrict, unrestrict, ban, unban,
|
search=search,
|
||||||
promote, demote, info, settings, pinned, edit, delete)):
|
min_id=min_id,
|
||||||
events_filter = types.ChannelAdminLogEventsFilter(
|
max_id=max_id,
|
||||||
join=join, leave=leave, invite=invite, ban=restrict,
|
join=join,
|
||||||
unban=unrestrict, kick=ban, unkick=unban, promote=promote,
|
leave=leave,
|
||||||
demote=demote, info=info, settings=settings, pinned=pinned,
|
invite=invite,
|
||||||
edit=edit, delete=delete
|
restrict=restrict,
|
||||||
)
|
unrestrict=unrestrict,
|
||||||
else:
|
ban=ban,
|
||||||
events_filter = None
|
unban=unban,
|
||||||
|
promote=promote,
|
||||||
entity = await self.get_input_entity(entity)
|
demote=demote,
|
||||||
|
info=info,
|
||||||
admin_list = []
|
settings=settings,
|
||||||
if admins:
|
pinned=pinned,
|
||||||
if not utils.is_list_like(admins):
|
edit=edit,
|
||||||
admins = (admins,)
|
delete=delete
|
||||||
|
|
||||||
for admin in admins:
|
|
||||||
admin_list.append(await self.get_input_entity(admin))
|
|
||||||
|
|
||||||
request = functions.channels.GetAdminLogRequest(
|
|
||||||
entity, q=search or '', min_id=min_id, max_id=max_id,
|
|
||||||
limit=0, events_filter=events_filter, admins=admin_list or None
|
|
||||||
)
|
)
|
||||||
while limit > 0:
|
|
||||||
request.limit = min(limit, 100)
|
|
||||||
result = await self(request)
|
|
||||||
limit -= len(result.events)
|
|
||||||
entities = {utils.get_peer_id(x): x
|
|
||||||
for x in itertools.chain(result.users, result.chats)}
|
|
||||||
|
|
||||||
request.max_id = min((e.id for e in result.events), default=0)
|
|
||||||
for ev in result.events:
|
|
||||||
if isinstance(ev.action,
|
|
||||||
types.ChannelAdminLogEventActionEditMessage):
|
|
||||||
ev.action.prev_message._finish_init(self, entities, entity)
|
|
||||||
ev.action.new_message._finish_init(self, entities, entity)
|
|
||||||
|
|
||||||
elif isinstance(ev.action,
|
|
||||||
types.ChannelAdminLogEventActionDeleteMessage):
|
|
||||||
ev.action.message._finish_init(self, entities, entity)
|
|
||||||
|
|
||||||
await yield_(custom.AdminLogEvent(ev, entities))
|
|
||||||
|
|
||||||
if len(result.events) < request.limit:
|
|
||||||
break
|
|
||||||
|
|
||||||
async def get_admin_log(self, *args, **kwargs):
|
async def get_admin_log(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Same as `iter_admin_log`, but returns a ``list`` instead.
|
Same as `iter_admin_log`, but returns a ``list`` instead.
|
||||||
"""
|
"""
|
||||||
admin_log = []
|
return await self.iter_admin_log(*args, **kwargs).collect()
|
||||||
async for x in self.iter_admin_log(*args, **kwargs):
|
|
||||||
admin_log.append(x)
|
|
||||||
return admin_log
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -1,18 +1,101 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from async_generator import async_generator, yield_
|
|
||||||
|
|
||||||
from .users import UserMethods
|
from .users import UserMethods
|
||||||
from .. import utils, helpers
|
from .. import utils
|
||||||
|
from ..requestiter import RequestIter
|
||||||
from ..tl import types, functions, custom
|
from ..tl import types, functions, custom
|
||||||
|
|
||||||
|
|
||||||
|
class _DialogsIter(RequestIter):
|
||||||
|
async def _init(
|
||||||
|
self, offset_date, offset_id, offset_peer, ignore_migrated
|
||||||
|
):
|
||||||
|
self.request = functions.messages.GetDialogsRequest(
|
||||||
|
offset_date=offset_date,
|
||||||
|
offset_id=offset_id,
|
||||||
|
offset_peer=offset_peer,
|
||||||
|
limit=1,
|
||||||
|
hash=0
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.limit == 0:
|
||||||
|
# Special case, get a single dialog and determine count
|
||||||
|
dialogs = await self.client(self.request)
|
||||||
|
self.total = getattr(dialogs, 'count', len(dialogs.dialogs))
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
self.seen = set()
|
||||||
|
self.offset_date = offset_date
|
||||||
|
self.ignore_migrated = ignore_migrated
|
||||||
|
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
self.request.limit = min(self.left, 100)
|
||||||
|
r = await self.client(self.request)
|
||||||
|
|
||||||
|
self.total = getattr(r, 'count', len(r.dialogs))
|
||||||
|
|
||||||
|
entities = {utils.get_peer_id(x): x
|
||||||
|
for x in itertools.chain(r.users, r.chats)}
|
||||||
|
|
||||||
|
messages = {}
|
||||||
|
for m in r.messages:
|
||||||
|
m._finish_init(self, entities, None)
|
||||||
|
messages[m.id] = m
|
||||||
|
|
||||||
|
for d in r.dialogs:
|
||||||
|
# We check the offset date here because Telegram may ignore it
|
||||||
|
if self.offset_date:
|
||||||
|
date = getattr(messages.get(
|
||||||
|
d.top_message, None), 'date', None)
|
||||||
|
|
||||||
|
if not date or date.timestamp() > self.offset_date.timestamp():
|
||||||
|
continue
|
||||||
|
|
||||||
|
peer_id = utils.get_peer_id(d.peer)
|
||||||
|
if peer_id not in self.seen:
|
||||||
|
self.seen.add(peer_id)
|
||||||
|
cd = custom.Dialog(self, d, entities, messages)
|
||||||
|
if cd.dialog.pts:
|
||||||
|
self.client._channel_pts[cd.id] = cd.dialog.pts
|
||||||
|
|
||||||
|
if not self.ignore_migrated or getattr(
|
||||||
|
cd.entity, 'migrated_to', None) is None:
|
||||||
|
self.buffer.append(cd)
|
||||||
|
|
||||||
|
if len(r.dialogs) < self.request.limit\
|
||||||
|
or not isinstance(r, types.messages.DialogsSlice):
|
||||||
|
# Less than we requested means we reached the end, or
|
||||||
|
# we didn't get a DialogsSlice which means we got all.
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self.request.offset_id == r.messages[-1].id:
|
||||||
|
# In some very rare cases this will get stuck in an infinite
|
||||||
|
# loop, where the offsets will get reused over and over. If
|
||||||
|
# the new offset is the same as the one before, break already.
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.request.offset_id = r.messages[-1].id
|
||||||
|
self.request.exclude_pinned = True
|
||||||
|
self.request.offset_date = r.messages[-1].date
|
||||||
|
self.request.offset_peer =\
|
||||||
|
entities[utils.get_peer_id(r.dialogs[-1].peer)]
|
||||||
|
|
||||||
|
|
||||||
|
class _DraftsIter(RequestIter):
|
||||||
|
async def _init(self, **kwargs):
|
||||||
|
r = await self.client(functions.messages.GetAllDraftsRequest())
|
||||||
|
self.buffer.extend(custom.Draft._from_update(self.client, u)
|
||||||
|
for u in r.updates)
|
||||||
|
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class DialogMethods(UserMethods):
|
class DialogMethods(UserMethods):
|
||||||
|
|
||||||
# region Public methods
|
# region Public methods
|
||||||
|
|
||||||
@async_generator
|
def iter_dialogs(
|
||||||
async def iter_dialogs(
|
|
||||||
self, limit=None, *, offset_date=None, offset_id=0,
|
self, limit=None, *, offset_date=None, offset_id=0,
|
||||||
offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
|
offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
|
||||||
_total=None):
|
_total=None):
|
||||||
|
@ -50,99 +133,23 @@ class DialogMethods(UserMethods):
|
||||||
Yields:
|
Yields:
|
||||||
Instances of `telethon.tl.custom.dialog.Dialog`.
|
Instances of `telethon.tl.custom.dialog.Dialog`.
|
||||||
"""
|
"""
|
||||||
limit = float('inf') if limit is None else int(limit)
|
return _DialogsIter(
|
||||||
if limit == 0:
|
self,
|
||||||
if not _total:
|
limit,
|
||||||
return
|
|
||||||
# Special case, get a single dialog and determine count
|
|
||||||
dialogs = await self(functions.messages.GetDialogsRequest(
|
|
||||||
offset_date=offset_date,
|
|
||||||
offset_id=offset_id,
|
|
||||||
offset_peer=offset_peer,
|
|
||||||
limit=1,
|
|
||||||
hash=0
|
|
||||||
))
|
|
||||||
_total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
|
|
||||||
return
|
|
||||||
|
|
||||||
seen = set()
|
|
||||||
req = functions.messages.GetDialogsRequest(
|
|
||||||
offset_date=offset_date,
|
offset_date=offset_date,
|
||||||
offset_id=offset_id,
|
offset_id=offset_id,
|
||||||
offset_peer=offset_peer,
|
offset_peer=offset_peer,
|
||||||
limit=0,
|
ignore_migrated=ignore_migrated
|
||||||
hash=0
|
|
||||||
)
|
)
|
||||||
while len(seen) < limit:
|
|
||||||
req.limit = min(limit - len(seen), 100)
|
|
||||||
r = await self(req)
|
|
||||||
|
|
||||||
if _total:
|
|
||||||
_total[0] = getattr(r, 'count', len(r.dialogs))
|
|
||||||
|
|
||||||
entities = {utils.get_peer_id(x): x
|
|
||||||
for x in itertools.chain(r.users, r.chats)}
|
|
||||||
|
|
||||||
messages = {}
|
|
||||||
for m in r.messages:
|
|
||||||
m._finish_init(self, entities, None)
|
|
||||||
messages[m.id] = m
|
|
||||||
|
|
||||||
# Happens when there are pinned dialogs
|
|
||||||
if len(r.dialogs) > limit:
|
|
||||||
r.dialogs = r.dialogs[:limit]
|
|
||||||
|
|
||||||
for d in r.dialogs:
|
|
||||||
if offset_date:
|
|
||||||
date = getattr(messages.get(
|
|
||||||
d.top_message, None), 'date', None)
|
|
||||||
|
|
||||||
if not date or date.timestamp() > offset_date.timestamp():
|
|
||||||
continue
|
|
||||||
|
|
||||||
peer_id = utils.get_peer_id(d.peer)
|
|
||||||
if peer_id not in seen:
|
|
||||||
seen.add(peer_id)
|
|
||||||
cd = custom.Dialog(self, d, entities, messages)
|
|
||||||
if cd.dialog.pts:
|
|
||||||
self._channel_pts[cd.id] = cd.dialog.pts
|
|
||||||
|
|
||||||
if not ignore_migrated or getattr(
|
|
||||||
cd.entity, 'migrated_to', None) is None:
|
|
||||||
await yield_(cd)
|
|
||||||
|
|
||||||
if len(r.dialogs) < req.limit\
|
|
||||||
or not isinstance(r, types.messages.DialogsSlice):
|
|
||||||
# Less than we requested means we reached the end, or
|
|
||||||
# we didn't get a DialogsSlice which means we got all.
|
|
||||||
break
|
|
||||||
|
|
||||||
req.offset_date = r.messages[-1].date
|
|
||||||
req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
|
|
||||||
if req.offset_id == r.messages[-1].id:
|
|
||||||
# In some very rare cases this will get stuck in an infinite
|
|
||||||
# loop, where the offsets will get reused over and over. If
|
|
||||||
# the new offset is the same as the one before, break already.
|
|
||||||
break
|
|
||||||
|
|
||||||
req.offset_id = r.messages[-1].id
|
|
||||||
req.exclude_pinned = True
|
|
||||||
|
|
||||||
async def get_dialogs(self, *args, **kwargs):
|
async def get_dialogs(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Same as `iter_dialogs`, but returns a
|
Same as `iter_dialogs`, but returns a
|
||||||
`TotalList <telethon.helpers.TotalList>` instead.
|
`TotalList <telethon.helpers.TotalList>` instead.
|
||||||
"""
|
"""
|
||||||
total = [0]
|
return await self.iter_dialogs(*args, **kwargs).collect()
|
||||||
kwargs['_total'] = total
|
|
||||||
dialogs = helpers.TotalList()
|
|
||||||
async for x in self.iter_dialogs(*args, **kwargs):
|
|
||||||
dialogs.append(x)
|
|
||||||
dialogs.total = total[0]
|
|
||||||
return dialogs
|
|
||||||
|
|
||||||
@async_generator
|
def iter_drafts(self):
|
||||||
async def iter_drafts(self):
|
|
||||||
"""
|
"""
|
||||||
Iterator over all open draft messages.
|
Iterator over all open draft messages.
|
||||||
|
|
||||||
|
@ -151,18 +158,14 @@ class DialogMethods(UserMethods):
|
||||||
to change the message or `telethon.tl.custom.draft.Draft.delete`
|
to change the message or `telethon.tl.custom.draft.Draft.delete`
|
||||||
among other things.
|
among other things.
|
||||||
"""
|
"""
|
||||||
r = await self(functions.messages.GetAllDraftsRequest())
|
# TODO Passing a limit here makes no sense
|
||||||
for update in r.updates:
|
return _DraftsIter(self, None)
|
||||||
await yield_(custom.Draft._from_update(self, update))
|
|
||||||
|
|
||||||
async def get_drafts(self):
|
async def get_drafts(self):
|
||||||
"""
|
"""
|
||||||
Same as :meth:`iter_drafts`, but returns a list instead.
|
Same as :meth:`iter_drafts`, but returns a list instead.
|
||||||
"""
|
"""
|
||||||
result = []
|
return await self.iter_drafts().collect()
|
||||||
async for x in self.iter_drafts():
|
|
||||||
result.append(x)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def conversation(
|
def conversation(
|
||||||
self, entity,
|
self, entity,
|
||||||
|
|
|
@ -1,14 +1,282 @@
|
||||||
import asyncio
|
|
||||||
import itertools
|
import itertools
|
||||||
import time
|
|
||||||
|
|
||||||
from async_generator import async_generator, yield_
|
|
||||||
|
|
||||||
from .messageparse import MessageParseMethods
|
from .messageparse import MessageParseMethods
|
||||||
from .uploads import UploadMethods
|
from .uploads import UploadMethods
|
||||||
from .buttons import ButtonMethods
|
from .buttons import ButtonMethods
|
||||||
from .. import helpers, utils, errors
|
from .. import helpers, utils, errors
|
||||||
from ..tl import types, functions
|
from ..tl import types, functions
|
||||||
|
from ..requestiter import RequestIter
|
||||||
|
|
||||||
|
|
||||||
|
# TODO Maybe RequestIter could rather have the update offset here?
|
||||||
|
# Maybe init should return the request to be used and it be
|
||||||
|
# called automatically? And another method to just process it.
|
||||||
|
class _MessagesIter(RequestIter):
|
||||||
|
"""
|
||||||
|
Common factor for all requests that need to iterate over messages.
|
||||||
|
"""
|
||||||
|
async def _init(
|
||||||
|
self, entity, offset_id, min_id, max_id, from_user,
|
||||||
|
batch_size, offset_date, add_offset, filter, search
|
||||||
|
):
|
||||||
|
# Note that entity being ``None`` will perform a global search.
|
||||||
|
if entity:
|
||||||
|
self.entity = await self.client.get_input_entity(entity)
|
||||||
|
else:
|
||||||
|
self.entity = None
|
||||||
|
if self.reverse:
|
||||||
|
raise ValueError('Cannot reverse global search')
|
||||||
|
|
||||||
|
# Telegram doesn't like min_id/max_id. If these IDs are low enough
|
||||||
|
# (starting from last_id - 100), the request will return nothing.
|
||||||
|
#
|
||||||
|
# We can emulate their behaviour locally by setting offset = max_id
|
||||||
|
# and simply stopping once we hit a message with ID <= min_id.
|
||||||
|
if self.reverse:
|
||||||
|
offset_id = max(offset_id, min_id)
|
||||||
|
if offset_id and max_id:
|
||||||
|
if max_id - offset_id <= 1:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
if not max_id:
|
||||||
|
max_id = float('inf')
|
||||||
|
else:
|
||||||
|
offset_id = max(offset_id, max_id)
|
||||||
|
if offset_id and min_id:
|
||||||
|
if offset_id - min_id <= 1:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
if self.reverse:
|
||||||
|
if offset_id:
|
||||||
|
offset_id += 1
|
||||||
|
else:
|
||||||
|
offset_id = 1
|
||||||
|
|
||||||
|
if from_user:
|
||||||
|
from_user = await self.client.get_input_entity(from_user)
|
||||||
|
if not isinstance(from_user, (
|
||||||
|
types.InputPeerUser, types.InputPeerSelf)):
|
||||||
|
from_user = None # Ignore from_user unless it's a user
|
||||||
|
|
||||||
|
if from_user:
|
||||||
|
self.from_id = await self.client.get_peer_id(from_user)
|
||||||
|
else:
|
||||||
|
self.from_id = None
|
||||||
|
|
||||||
|
if not self.entity:
|
||||||
|
self.request = functions.messages.SearchGlobalRequest(
|
||||||
|
q=search or '',
|
||||||
|
offset_date=offset_date,
|
||||||
|
offset_peer=types.InputPeerEmpty(),
|
||||||
|
offset_id=offset_id,
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
elif search is not None or filter or from_user:
|
||||||
|
if filter is None:
|
||||||
|
filter = types.InputMessagesFilterEmpty()
|
||||||
|
|
||||||
|
# Telegram completely ignores `from_id` in private chats
|
||||||
|
if isinstance(
|
||||||
|
self.entity, (types.InputPeerUser, types.InputPeerSelf)):
|
||||||
|
# Don't bother sending `from_user` (it's ignored anyway),
|
||||||
|
# but keep `from_id` defined above to check it locally.
|
||||||
|
from_user = None
|
||||||
|
else:
|
||||||
|
# Do send `from_user` to do the filtering server-side,
|
||||||
|
# and set `from_id` to None to avoid checking it locally.
|
||||||
|
self.from_id = None
|
||||||
|
|
||||||
|
self.request = functions.messages.SearchRequest(
|
||||||
|
peer=self.entity,
|
||||||
|
q=search or '',
|
||||||
|
filter=filter() if isinstance(filter, type) else filter,
|
||||||
|
min_date=None,
|
||||||
|
max_date=offset_date,
|
||||||
|
offset_id=offset_id,
|
||||||
|
add_offset=add_offset,
|
||||||
|
limit=0, # Search actually returns 0 items if we ask it to
|
||||||
|
max_id=0,
|
||||||
|
min_id=0,
|
||||||
|
hash=0,
|
||||||
|
from_id=from_user
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.request = functions.messages.GetHistoryRequest(
|
||||||
|
peer=self.entity,
|
||||||
|
limit=1,
|
||||||
|
offset_date=offset_date,
|
||||||
|
offset_id=offset_id,
|
||||||
|
min_id=0,
|
||||||
|
max_id=0,
|
||||||
|
add_offset=add_offset,
|
||||||
|
hash=0
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.limit == 0:
|
||||||
|
# No messages, but we still need to know the total message count
|
||||||
|
result = await self.client(self.request)
|
||||||
|
if isinstance(result, types.messages.MessagesNotModified):
|
||||||
|
self.total = result.count
|
||||||
|
else:
|
||||||
|
self.total = getattr(result, 'count', len(result.messages))
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
if self.wait_time is None:
|
||||||
|
self.wait_time = 1 if self.limit > 3000 else 0
|
||||||
|
|
||||||
|
# Telegram has a hard limit of 100.
|
||||||
|
# We don't need to fetch 100 if the limit is less.
|
||||||
|
self.batch_size = min(max(batch_size, 1), min(100, self.limit))
|
||||||
|
|
||||||
|
# When going in reverse we need an offset of `-limit`, but we
|
||||||
|
# also want to respect what the user passed, so add them together.
|
||||||
|
if self.reverse:
|
||||||
|
self.request.add_offset -= self.batch_size
|
||||||
|
|
||||||
|
self.add_offset = add_offset
|
||||||
|
self.max_id = max_id
|
||||||
|
self.min_id = min_id
|
||||||
|
self.last_id = 0 if self.reverse else float('inf')
|
||||||
|
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
self.request.limit = min(self.left, self.batch_size)
|
||||||
|
if self.reverse and self.request.limit != self.batch_size:
|
||||||
|
# Remember that we need -limit when going in reverse
|
||||||
|
self.request.add_offset = self.add_offset - self.request.limit
|
||||||
|
|
||||||
|
r = await self.client(self.request)
|
||||||
|
self.total = getattr(r, 'count', len(r.messages))
|
||||||
|
|
||||||
|
entities = {utils.get_peer_id(x): x
|
||||||
|
for x in itertools.chain(r.users, r.chats)}
|
||||||
|
|
||||||
|
messages = reversed(r.messages) if self.reverse else r.messages
|
||||||
|
for message in messages:
|
||||||
|
if (isinstance(message, types.MessageEmpty)
|
||||||
|
or self.from_id and message.from_id != self.from_id):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not self._message_in_range(message):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# There has been reports that on bad connections this method
|
||||||
|
# was returning duplicated IDs sometimes. Using ``last_id``
|
||||||
|
# is an attempt to avoid these duplicates, since the message
|
||||||
|
# IDs are returned in descending order (or asc if reverse).
|
||||||
|
self.last_id = message.id
|
||||||
|
message._finish_init(self.client, entities, self.entity)
|
||||||
|
self.buffer.append(message)
|
||||||
|
|
||||||
|
if len(r.messages) < self.request.limit:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Get the last message that's not empty (in some rare cases
|
||||||
|
# it can happen that the last message is :tl:`MessageEmpty`)
|
||||||
|
if self.buffer:
|
||||||
|
self._update_offset(self.buffer[-1])
|
||||||
|
else:
|
||||||
|
# There are some cases where all the messages we get start
|
||||||
|
# being empty. This can happen on migrated mega-groups if
|
||||||
|
# the history was cleared, and we're using search. Telegram
|
||||||
|
# acts incredibly weird sometimes. Messages are returned but
|
||||||
|
# only "empty", not their contents. If this is the case we
|
||||||
|
# should just give up since there won't be any new Message.
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _message_in_range(self, message):
|
||||||
|
"""
|
||||||
|
Determine whether the given message is in the range or
|
||||||
|
it should be ignored (and avoid loading more chunks).
|
||||||
|
"""
|
||||||
|
# No entity means message IDs between chats may vary
|
||||||
|
if self.entity:
|
||||||
|
if self.reverse:
|
||||||
|
if message.id <= self.last_id or message.id >= self.max_id:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if message.id >= self.last_id or message.id <= self.min_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _update_offset(self, last_message):
|
||||||
|
"""
|
||||||
|
After making the request, update its offset with the last message.
|
||||||
|
"""
|
||||||
|
self.request.offset_id = last_message.id
|
||||||
|
if self.reverse:
|
||||||
|
# We want to skip the one we already have
|
||||||
|
self.request.offset_id += 1
|
||||||
|
|
||||||
|
if isinstance(self.request, functions.messages.SearchRequest):
|
||||||
|
# Unlike getHistory and searchGlobal that use *offset* date,
|
||||||
|
# this is *max* date. This means that doing a search in reverse
|
||||||
|
# will break it. Since it's not really needed once we're going
|
||||||
|
# (only for the first request), it's safe to just clear it off.
|
||||||
|
self.request.max_date = None
|
||||||
|
else:
|
||||||
|
# getHistory and searchGlobal call it offset_date
|
||||||
|
self.request.offset_date = last_message.date
|
||||||
|
|
||||||
|
if isinstance(self.request, functions.messages.SearchGlobalRequest):
|
||||||
|
self.request.offset_peer = last_message.input_chat
|
||||||
|
|
||||||
|
|
||||||
|
class _IDsIter(RequestIter):
|
||||||
|
async def _init(self, entity, ids):
|
||||||
|
# TODO We never actually split IDs in chunks, but maybe we should
|
||||||
|
if not utils.is_list_like(ids):
|
||||||
|
self.ids = [ids]
|
||||||
|
elif not ids:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
elif self.reverse:
|
||||||
|
self.ids = list(reversed(ids))
|
||||||
|
else:
|
||||||
|
self.ids = ids
|
||||||
|
|
||||||
|
if entity:
|
||||||
|
entity = await self.client.get_input_entity(entity)
|
||||||
|
|
||||||
|
self.total = len(ids)
|
||||||
|
|
||||||
|
from_id = None # By default, no need to validate from_id
|
||||||
|
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
|
||||||
|
try:
|
||||||
|
r = await self.client(
|
||||||
|
functions.channels.GetMessagesRequest(entity, ids))
|
||||||
|
except errors.MessageIdsEmptyError:
|
||||||
|
# All IDs were invalid, use a dummy result
|
||||||
|
r = types.messages.MessagesNotModified(len(ids))
|
||||||
|
else:
|
||||||
|
r = await self.client(functions.messages.GetMessagesRequest(ids))
|
||||||
|
if entity:
|
||||||
|
from_id = await self.client.get_peer_id(entity)
|
||||||
|
|
||||||
|
if isinstance(r, types.messages.MessagesNotModified):
|
||||||
|
self.buffer.extend(None for _ in ids)
|
||||||
|
return
|
||||||
|
|
||||||
|
entities = {utils.get_peer_id(x): x
|
||||||
|
for x in itertools.chain(r.users, r.chats)}
|
||||||
|
|
||||||
|
# Telegram seems to return the messages in the order in which
|
||||||
|
# we asked them for, so we don't need to check it ourselves,
|
||||||
|
# unless some messages were invalid in which case Telegram
|
||||||
|
# may decide to not send them at all.
|
||||||
|
#
|
||||||
|
# The passed message IDs may not belong to the desired entity
|
||||||
|
# since the user can enter arbitrary numbers which can belong to
|
||||||
|
# arbitrary chats. Validate these unless ``from_id is None``.
|
||||||
|
for message in r.messages:
|
||||||
|
if isinstance(message, types.MessageEmpty) or (
|
||||||
|
from_id and message.chat_id != from_id):
|
||||||
|
self.buffer.append(None)
|
||||||
|
else:
|
||||||
|
message._finish_init(self.client, entities, entity)
|
||||||
|
self.buffer.append(message)
|
||||||
|
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
return True # no next chunk, all done in init
|
||||||
|
|
||||||
|
|
||||||
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
|
@ -17,8 +285,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
|
|
||||||
# region Message retrieval
|
# region Message retrieval
|
||||||
|
|
||||||
@async_generator
|
def iter_messages(
|
||||||
async def iter_messages(
|
|
||||||
self, entity, limit=None, *, offset_date=None, offset_id=0,
|
self, entity, limit=None, *, offset_date=None, offset_id=0,
|
||||||
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
|
max_id=0, min_id=0, add_offset=0, search=None, filter=None,
|
||||||
from_user=None, batch_size=100, wait_time=None, ids=None,
|
from_user=None, batch_size=100, wait_time=None, ids=None,
|
||||||
|
@ -133,208 +400,26 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
an higher limit, so you're free to set the ``batch_size`` that
|
an higher limit, so you're free to set the ``batch_size`` that
|
||||||
you think may be good.
|
you think may be good.
|
||||||
"""
|
"""
|
||||||
# Note that entity being ``None`` is intended to get messages by
|
|
||||||
# ID under no specific chat, and also to request a global search.
|
|
||||||
if entity:
|
|
||||||
entity = await self.get_input_entity(entity)
|
|
||||||
|
|
||||||
if ids:
|
if ids is not None:
|
||||||
if not utils.is_list_like(ids):
|
return _IDsIter(self, limit, entity=entity, ids=ids)
|
||||||
ids = (ids,)
|
|
||||||
if reverse:
|
|
||||||
ids = list(reversed(ids))
|
|
||||||
async for x in self._iter_ids(entity, ids, total=_total):
|
|
||||||
await yield_(x)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Telegram doesn't like min_id/max_id. If these IDs are low enough
|
return _MessagesIter(
|
||||||
# (starting from last_id - 100), the request will return nothing.
|
client=self,
|
||||||
#
|
reverse=reverse,
|
||||||
# We can emulate their behaviour locally by setting offset = max_id
|
wait_time=wait_time,
|
||||||
# and simply stopping once we hit a message with ID <= min_id.
|
limit=limit,
|
||||||
if reverse:
|
entity=entity,
|
||||||
offset_id = max(offset_id, min_id)
|
offset_id=offset_id,
|
||||||
if offset_id and max_id:
|
min_id=min_id,
|
||||||
if max_id - offset_id <= 1:
|
max_id=max_id,
|
||||||
return
|
from_user=from_user,
|
||||||
|
batch_size=batch_size,
|
||||||
if not max_id:
|
offset_date=offset_date,
|
||||||
max_id = float('inf')
|
add_offset=add_offset,
|
||||||
else:
|
filter=filter,
|
||||||
offset_id = max(offset_id, max_id)
|
search=search
|
||||||
if offset_id and min_id:
|
)
|
||||||
if offset_id - min_id <= 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
if reverse:
|
|
||||||
if offset_id:
|
|
||||||
offset_id += 1
|
|
||||||
else:
|
|
||||||
offset_id = 1
|
|
||||||
|
|
||||||
if from_user:
|
|
||||||
from_user = await self.get_input_entity(from_user)
|
|
||||||
if not isinstance(from_user, (
|
|
||||||
types.InputPeerUser, types.InputPeerSelf)):
|
|
||||||
from_user = None # Ignore from_user unless it's a user
|
|
||||||
|
|
||||||
from_id = (await self.get_peer_id(from_user)) if from_user else None
|
|
||||||
|
|
||||||
limit = float('inf') if limit is None else int(limit)
|
|
||||||
if not entity:
|
|
||||||
if reverse:
|
|
||||||
raise ValueError('Cannot reverse global search')
|
|
||||||
|
|
||||||
reverse = None
|
|
||||||
request = functions.messages.SearchGlobalRequest(
|
|
||||||
q=search or '',
|
|
||||||
offset_date=offset_date,
|
|
||||||
offset_peer=types.InputPeerEmpty(),
|
|
||||||
offset_id=offset_id,
|
|
||||||
limit=1
|
|
||||||
)
|
|
||||||
elif search is not None or filter or from_user:
|
|
||||||
if filter is None:
|
|
||||||
filter = types.InputMessagesFilterEmpty()
|
|
||||||
|
|
||||||
# Telegram completely ignores `from_id` in private chats
|
|
||||||
if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)):
|
|
||||||
# Don't bother sending `from_user` (it's ignored anyway),
|
|
||||||
# but keep `from_id` defined above to check it locally.
|
|
||||||
from_user = None
|
|
||||||
else:
|
|
||||||
# Do send `from_user` to do the filtering server-side,
|
|
||||||
# and set `from_id` to None to avoid checking it locally.
|
|
||||||
from_id = None
|
|
||||||
|
|
||||||
request = functions.messages.SearchRequest(
|
|
||||||
peer=entity,
|
|
||||||
q=search or '',
|
|
||||||
filter=filter() if isinstance(filter, type) else filter,
|
|
||||||
min_date=None,
|
|
||||||
max_date=offset_date,
|
|
||||||
offset_id=offset_id,
|
|
||||||
add_offset=add_offset,
|
|
||||||
limit=0, # Search actually returns 0 items if we ask it to
|
|
||||||
max_id=0,
|
|
||||||
min_id=0,
|
|
||||||
hash=0,
|
|
||||||
from_id=from_user
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
request = functions.messages.GetHistoryRequest(
|
|
||||||
peer=entity,
|
|
||||||
limit=1,
|
|
||||||
offset_date=offset_date,
|
|
||||||
offset_id=offset_id,
|
|
||||||
min_id=0,
|
|
||||||
max_id=0,
|
|
||||||
add_offset=add_offset,
|
|
||||||
hash=0
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit == 0:
|
|
||||||
if not _total:
|
|
||||||
return
|
|
||||||
# No messages, but we still need to know the total message count
|
|
||||||
result = await self(request)
|
|
||||||
if isinstance(result, types.messages.MessagesNotModified):
|
|
||||||
_total[0] = result.count
|
|
||||||
else:
|
|
||||||
_total[0] = getattr(result, 'count', len(result.messages))
|
|
||||||
return
|
|
||||||
|
|
||||||
if wait_time is None:
|
|
||||||
wait_time = 1 if limit > 3000 else 0
|
|
||||||
|
|
||||||
have = 0
|
|
||||||
last_id = 0 if reverse else float('inf')
|
|
||||||
|
|
||||||
# Telegram has a hard limit of 100.
|
|
||||||
# We don't need to fetch 100 if the limit is less.
|
|
||||||
batch_size = min(max(batch_size, 1), min(100, limit))
|
|
||||||
|
|
||||||
# When going in reverse we need an offset of `-limit`, but we
|
|
||||||
# also want to respect what the user passed, so add them together.
|
|
||||||
if reverse:
|
|
||||||
request.add_offset -= batch_size
|
|
||||||
|
|
||||||
while have < limit:
|
|
||||||
start = time.time()
|
|
||||||
|
|
||||||
request.limit = min(limit - have, batch_size)
|
|
||||||
if reverse and request.limit != batch_size:
|
|
||||||
# Remember that we need -limit when going in reverse
|
|
||||||
request.add_offset = add_offset - request.limit
|
|
||||||
|
|
||||||
r = await self(request)
|
|
||||||
if _total:
|
|
||||||
_total[0] = getattr(r, 'count', len(r.messages))
|
|
||||||
|
|
||||||
entities = {utils.get_peer_id(x): x
|
|
||||||
for x in itertools.chain(r.users, r.chats)}
|
|
||||||
|
|
||||||
messages = reversed(r.messages) if reverse else r.messages
|
|
||||||
for message in messages:
|
|
||||||
if (isinstance(message, types.MessageEmpty)
|
|
||||||
or from_id and message.from_id != from_id):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if reverse is None:
|
|
||||||
pass
|
|
||||||
elif reverse:
|
|
||||||
if message.id <= last_id or message.id >= max_id:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
if message.id >= last_id or message.id <= min_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
# There has been reports that on bad connections this method
|
|
||||||
# was returning duplicated IDs sometimes. Using ``last_id``
|
|
||||||
# is an attempt to avoid these duplicates, since the message
|
|
||||||
# IDs are returned in descending order (or asc if reverse).
|
|
||||||
last_id = message.id
|
|
||||||
|
|
||||||
message._finish_init(self, entities, entity)
|
|
||||||
await yield_(message)
|
|
||||||
have += 1
|
|
||||||
|
|
||||||
if len(r.messages) < request.limit:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Find the first message that's not empty (in some rare cases
|
|
||||||
# it can happen that the last message is :tl:`MessageEmpty`)
|
|
||||||
last_message = None
|
|
||||||
messages = r.messages if reverse else reversed(r.messages)
|
|
||||||
for m in messages:
|
|
||||||
if not isinstance(m, types.MessageEmpty):
|
|
||||||
last_message = m
|
|
||||||
break
|
|
||||||
|
|
||||||
if last_message is None:
|
|
||||||
# There are some cases where all the messages we get start
|
|
||||||
# being empty. This can happen on migrated mega-groups if
|
|
||||||
# the history was cleared, and we're using search. Telegram
|
|
||||||
# acts incredibly weird sometimes. Messages are returned but
|
|
||||||
# only "empty", not their contents. If this is the case we
|
|
||||||
# should just give up since there won't be any new Message.
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
request.offset_id = last_message.id
|
|
||||||
if isinstance(request, functions.messages.SearchRequest):
|
|
||||||
request.max_date = last_message.date
|
|
||||||
else:
|
|
||||||
# getHistory and searchGlobal call it offset_date
|
|
||||||
request.offset_date = last_message.date
|
|
||||||
|
|
||||||
if isinstance(request, functions.messages.SearchGlobalRequest):
|
|
||||||
request.offset_peer = last_message.input_chat
|
|
||||||
elif reverse:
|
|
||||||
# We want to skip the one we already have
|
|
||||||
request.offset_id += 1
|
|
||||||
|
|
||||||
await asyncio.sleep(
|
|
||||||
max(wait_time - (time.time() - start), 0), loop=self._loop)
|
|
||||||
|
|
||||||
async def get_messages(self, *args, **kwargs):
|
async def get_messages(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -353,23 +438,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
a single `Message <telethon.tl.custom.message.Message>` will be
|
a single `Message <telethon.tl.custom.message.Message>` will be
|
||||||
returned for convenience instead of a list.
|
returned for convenience instead of a list.
|
||||||
"""
|
"""
|
||||||
total = [0]
|
|
||||||
kwargs['_total'] = total
|
|
||||||
if len(args) == 1 and 'limit' not in kwargs:
|
if len(args) == 1 and 'limit' not in kwargs:
|
||||||
if 'min_id' in kwargs and 'max_id' in kwargs:
|
if 'min_id' in kwargs and 'max_id' in kwargs:
|
||||||
kwargs['limit'] = None
|
kwargs['limit'] = None
|
||||||
else:
|
else:
|
||||||
kwargs['limit'] = 1
|
kwargs['limit'] = 1
|
||||||
|
|
||||||
msgs = helpers.TotalList()
|
it = self.iter_messages(*args, **kwargs)
|
||||||
async for x in self.iter_messages(*args, **kwargs):
|
|
||||||
msgs.append(x)
|
|
||||||
msgs.total = total[0]
|
|
||||||
if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']):
|
|
||||||
# Check for empty list to handle InputMessageReplyTo
|
|
||||||
return msgs[0] if msgs else None
|
|
||||||
|
|
||||||
return msgs
|
ids = kwargs.get('ids')
|
||||||
|
if ids and not utils.is_list_like(ids):
|
||||||
|
async for message in it:
|
||||||
|
return message
|
||||||
|
else:
|
||||||
|
# Iterator exhausted = empty, to handle InputMessageReplyTo
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await it.collect()
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
@ -799,52 +884,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Private methods
|
|
||||||
|
|
||||||
@async_generator
|
|
||||||
async def _iter_ids(self, entity, ids, total):
|
|
||||||
"""
|
|
||||||
Special case for `iter_messages` when it should only fetch some IDs.
|
|
||||||
"""
|
|
||||||
if total:
|
|
||||||
total[0] = len(ids)
|
|
||||||
|
|
||||||
from_id = None # By default, no need to validate from_id
|
|
||||||
if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
|
|
||||||
try:
|
|
||||||
r = await self(
|
|
||||||
functions.channels.GetMessagesRequest(entity, ids))
|
|
||||||
except errors.MessageIdsEmptyError:
|
|
||||||
# All IDs were invalid, use a dummy result
|
|
||||||
r = types.messages.MessagesNotModified(len(ids))
|
|
||||||
else:
|
|
||||||
r = await self(functions.messages.GetMessagesRequest(ids))
|
|
||||||
if entity:
|
|
||||||
from_id = utils.get_peer_id(entity)
|
|
||||||
|
|
||||||
if isinstance(r, types.messages.MessagesNotModified):
|
|
||||||
for _ in ids:
|
|
||||||
await yield_(None)
|
|
||||||
return
|
|
||||||
|
|
||||||
entities = {utils.get_peer_id(x): x
|
|
||||||
for x in itertools.chain(r.users, r.chats)}
|
|
||||||
|
|
||||||
# Telegram seems to return the messages in the order in which
|
|
||||||
# we asked them for, so we don't need to check it ourselves,
|
|
||||||
# unless some messages were invalid in which case Telegram
|
|
||||||
# may decide to not send them at all.
|
|
||||||
#
|
|
||||||
# The passed message IDs may not belong to the desired entity
|
|
||||||
# since the user can enter arbitrary numbers which can belong to
|
|
||||||
# arbitrary chats. Validate these unless ``from_id is None``.
|
|
||||||
for message in r.messages:
|
|
||||||
if isinstance(message, types.MessageEmpty) or (
|
|
||||||
from_id and message.chat_id != from_id):
|
|
||||||
await yield_(None)
|
|
||||||
else:
|
|
||||||
message._finish_init(self, entities, entity)
|
|
||||||
await yield_(message)
|
|
||||||
|
|
||||||
# endregion
|
|
||||||
|
|
130
telethon/requestiter.py
Normal file
130
telethon/requestiter.py
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
import abc
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
|
||||||
|
|
||||||
|
# TODO There are two types of iterators for requests.
|
||||||
|
# One has a limit of items to retrieve, and the
|
||||||
|
# other has a list that must be called in chunks.
|
||||||
|
# Make classes for both here so it's easy to use.
|
||||||
|
class RequestIter(abc.ABC):
|
||||||
|
"""
|
||||||
|
Helper class to deal with requests that need offsets to iterate.
|
||||||
|
|
||||||
|
It has some facilities, such as automatically sleeping a desired
|
||||||
|
amount of time between requests if needed (but not more).
|
||||||
|
|
||||||
|
Can be used synchronously if the event loop is not running and
|
||||||
|
as an asynchronous iterator otherwise.
|
||||||
|
|
||||||
|
`limit` is the total amount of items that the iterator should return.
|
||||||
|
This is handled on this base class, and will be always ``>= 0``.
|
||||||
|
|
||||||
|
`left` will be reset every time the iterator is used and will indicate
|
||||||
|
the amount of items that should be emitted left, so that subclasses can
|
||||||
|
be more efficient and fetch only as many items as they need.
|
||||||
|
|
||||||
|
Iterators may be used with ``reversed``, and their `reverse` flag will
|
||||||
|
be set to ``True`` if that's the case. Note that if this flag is set,
|
||||||
|
`buffer` should be filled in reverse too.
|
||||||
|
"""
|
||||||
|
def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
|
||||||
|
self.client = client
|
||||||
|
self.reverse = reverse
|
||||||
|
self.wait_time = wait_time
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.limit = max(float('inf') if limit is None else limit, 0)
|
||||||
|
self.left = None
|
||||||
|
self.buffer = None
|
||||||
|
self.index = None
|
||||||
|
self.total = None
|
||||||
|
self.last_load = None
|
||||||
|
|
||||||
|
async def _init(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Called when asynchronous initialization is necessary. All keyword
|
||||||
|
arguments passed to `__init__` will be forwarded here, and it's
|
||||||
|
preferable to use named arguments in the subclasses without defaults
|
||||||
|
to avoid forgetting or misspelling any of them.
|
||||||
|
|
||||||
|
This method may ``raise StopAsyncIteration`` if it cannot continue.
|
||||||
|
|
||||||
|
This method may actually fill the initial buffer if it needs to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if self.buffer is None:
|
||||||
|
self.buffer = []
|
||||||
|
await self._init(**self.kwargs)
|
||||||
|
|
||||||
|
if self.left <= 0: # <= 0 because subclasses may change it
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
if self.index == len(self.buffer):
|
||||||
|
# asyncio will handle times <= 0 to sleep 0 seconds
|
||||||
|
if self.wait_time:
|
||||||
|
await asyncio.sleep(
|
||||||
|
self.wait_time - (time.time() - self.last_load),
|
||||||
|
loop=self.client.loop
|
||||||
|
)
|
||||||
|
self.last_load = time.time()
|
||||||
|
|
||||||
|
self.index = 0
|
||||||
|
self.buffer = []
|
||||||
|
if await self._load_next_chunk():
|
||||||
|
self.left = len(self.buffer)
|
||||||
|
|
||||||
|
if not self.buffer:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
result = self.buffer[self.index]
|
||||||
|
self.left -= 1
|
||||||
|
self.index += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
self.buffer = None
|
||||||
|
self.index = 0
|
||||||
|
self.last_load = 0
|
||||||
|
self.left = self.limit
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self.client.loop.is_running():
|
||||||
|
raise RuntimeError(
|
||||||
|
'You must use "async for" if the event loop '
|
||||||
|
'is running (i.e. you are inside an "async def")'
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.__aiter__()
|
||||||
|
|
||||||
|
async def collect(self):
|
||||||
|
"""
|
||||||
|
Create a `self` iterator and collect it into a `TotalList`
|
||||||
|
(a normal list with a `.total` attribute).
|
||||||
|
"""
|
||||||
|
result = helpers.TotalList()
|
||||||
|
async for message in self:
|
||||||
|
result.append(message)
|
||||||
|
|
||||||
|
result.total = self.total
|
||||||
|
return result
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def _load_next_chunk(self):
|
||||||
|
"""
|
||||||
|
Called when the next chunk is necessary.
|
||||||
|
|
||||||
|
It should extend the `buffer` with new items.
|
||||||
|
|
||||||
|
It should return ``True`` if it's the last chunk,
|
||||||
|
after which moment the method won't be called again
|
||||||
|
during the same iteration.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __reversed__(self):
|
||||||
|
self.reverse = not self.reverse
|
||||||
|
return self # __aiter__ will be called after, too
|
|
@ -14,8 +14,6 @@ import asyncio
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from async_generator import isasyncgenfunction
|
|
||||||
|
|
||||||
from .client.telegramclient import TelegramClient
|
from .client.telegramclient import TelegramClient
|
||||||
from .tl.custom import (
|
from .tl.custom import (
|
||||||
Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
|
Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
|
||||||
|
@ -24,22 +22,7 @@ from .tl.custom.chatgetter import ChatGetter
|
||||||
from .tl.custom.sendergetter import SenderGetter
|
from .tl.custom.sendergetter import SenderGetter
|
||||||
|
|
||||||
|
|
||||||
class _SyncGen:
|
def _syncify_wrap(t, method_name):
|
||||||
def __init__(self, gen):
|
|
||||||
self.gen = gen
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
try:
|
|
||||||
return asyncio.get_event_loop() \
|
|
||||||
.run_until_complete(self.gen.__anext__())
|
|
||||||
except StopAsyncIteration:
|
|
||||||
raise StopIteration from None
|
|
||||||
|
|
||||||
|
|
||||||
def _syncify_wrap(t, method_name, gen):
|
|
||||||
method = getattr(t, method_name)
|
method = getattr(t, method_name)
|
||||||
|
|
||||||
@functools.wraps(method)
|
@functools.wraps(method)
|
||||||
|
@ -48,8 +31,6 @@ def _syncify_wrap(t, method_name, gen):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if loop.is_running():
|
if loop.is_running():
|
||||||
return coro
|
return coro
|
||||||
elif gen:
|
|
||||||
return _SyncGen(coro)
|
|
||||||
else:
|
else:
|
||||||
return loop.run_until_complete(coro)
|
return loop.run_until_complete(coro)
|
||||||
|
|
||||||
|
@ -64,13 +45,14 @@ def syncify(*types):
|
||||||
into synchronous, which return either the coroutine or the result
|
into synchronous, which return either the coroutine or the result
|
||||||
based on whether ``asyncio's`` event loop is running.
|
based on whether ``asyncio's`` event loop is running.
|
||||||
"""
|
"""
|
||||||
|
# Our asynchronous generators all are `RequestIter`, which already
|
||||||
|
# provide a synchronous iterator variant, so we don't need to worry
|
||||||
|
# about asyncgenfunction's here.
|
||||||
for t in types:
|
for t in types:
|
||||||
for name in dir(t):
|
for name in dir(t):
|
||||||
if not name.startswith('_') or name == '__call__':
|
if not name.startswith('_') or name == '__call__':
|
||||||
if inspect.iscoroutinefunction(getattr(t, name)):
|
if inspect.iscoroutinefunction(getattr(t, name)):
|
||||||
_syncify_wrap(t, name, gen=False)
|
_syncify_wrap(t, name)
|
||||||
elif isasyncgenfunction(getattr(t, name)):
|
|
||||||
_syncify_wrap(t, name, gen=True)
|
|
||||||
|
|
||||||
|
|
||||||
syncify(TelegramClient, Draft, Dialog, MessageButton,
|
syncify(TelegramClient, Draft, Dialog, MessageButton,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user