mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-04 01:47:27 +03:00 
			
		
		
		
	Merge pull request #1115 from LonamiWebs/requestiter
Overhaul asynchronous generators
This commit is contained in:
		
						commit
						c99157ade2
					
				| 
						 | 
					@ -1,3 +1,2 @@
 | 
				
			||||||
pyaes
 | 
					pyaes
 | 
				
			||||||
rsa
 | 
					rsa
 | 
				
			||||||
async_generator
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										3
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								setup.py
									
									
									
									
									
								
							| 
						 | 
					@ -214,8 +214,7 @@ def main():
 | 
				
			||||||
            packages=find_packages(exclude=[
 | 
					            packages=find_packages(exclude=[
 | 
				
			||||||
                'telethon_*', 'run_tests.py', 'try_telethon.py'
 | 
					                'telethon_*', 'run_tests.py', 'try_telethon.py'
 | 
				
			||||||
            ]),
 | 
					            ]),
 | 
				
			||||||
            install_requires=['pyaes', 'rsa',
 | 
					            install_requires=['pyaes', 'rsa'],
 | 
				
			||||||
                              'async_generator'],
 | 
					 | 
				
			||||||
            extras_require={
 | 
					            extras_require={
 | 
				
			||||||
                'cryptg': ['cryptg']
 | 
					                'cryptg': ['cryptg']
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,19 +1,192 @@
 | 
				
			||||||
import itertools
 | 
					import itertools
 | 
				
			||||||
import sys
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from async_generator import async_generator, yield_
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .users import UserMethods
 | 
					from .users import UserMethods
 | 
				
			||||||
from .. import utils, helpers
 | 
					from .. import utils
 | 
				
			||||||
 | 
					from ..requestiter import RequestIter
 | 
				
			||||||
from ..tl import types, functions, custom
 | 
					from ..tl import types, functions, custom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _ParticipantsIter(RequestIter):
 | 
				
			||||||
 | 
					    async def _init(self, entity, filter, search, aggressive):
 | 
				
			||||||
 | 
					        if isinstance(filter, type):
 | 
				
			||||||
 | 
					            if filter in (types.ChannelParticipantsBanned,
 | 
				
			||||||
 | 
					                          types.ChannelParticipantsKicked,
 | 
				
			||||||
 | 
					                          types.ChannelParticipantsSearch,
 | 
				
			||||||
 | 
					                          types.ChannelParticipantsContacts):
 | 
				
			||||||
 | 
					                # These require a `q` parameter (support types for convenience)
 | 
				
			||||||
 | 
					                filter = filter('')
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                filter = filter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entity = await self.client.get_input_entity(entity)
 | 
				
			||||||
 | 
					        if search and (filter
 | 
				
			||||||
 | 
					                       or not isinstance(entity, types.InputPeerChannel)):
 | 
				
			||||||
 | 
					            # We need to 'search' ourselves unless we have a PeerChannel
 | 
				
			||||||
 | 
					            search = search.lower()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.filter_entity = lambda ent: (
 | 
				
			||||||
 | 
					                search in utils.get_display_name(ent).lower() or
 | 
				
			||||||
 | 
					                search in (getattr(ent, 'username', '') or None).lower()
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.filter_entity = lambda ent: True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(entity, types.InputPeerChannel):
 | 
				
			||||||
 | 
					            self.total = (await self.client(
 | 
				
			||||||
 | 
					                functions.channels.GetFullChannelRequest(entity)
 | 
				
			||||||
 | 
					            )).full_chat.participants_count
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.limit == 0:
 | 
				
			||||||
 | 
					                raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.seen = set()
 | 
				
			||||||
 | 
					            if aggressive and not filter:
 | 
				
			||||||
 | 
					                self.requests = [functions.channels.GetParticipantsRequest(
 | 
				
			||||||
 | 
					                    channel=entity,
 | 
				
			||||||
 | 
					                    filter=types.ChannelParticipantsSearch(x),
 | 
				
			||||||
 | 
					                    offset=0,
 | 
				
			||||||
 | 
					                    limit=200,
 | 
				
			||||||
 | 
					                    hash=0
 | 
				
			||||||
 | 
					                ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.requests = [functions.channels.GetParticipantsRequest(
 | 
				
			||||||
 | 
					                    channel=entity,
 | 
				
			||||||
 | 
					                    filter=filter or types.ChannelParticipantsSearch(search),
 | 
				
			||||||
 | 
					                    offset=0,
 | 
				
			||||||
 | 
					                    limit=200,
 | 
				
			||||||
 | 
					                    hash=0
 | 
				
			||||||
 | 
					                )]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        elif isinstance(entity, types.InputPeerChat):
 | 
				
			||||||
 | 
					            full = await self.client(
 | 
				
			||||||
 | 
					                functions.messages.GetFullChatRequest(entity.chat_id))
 | 
				
			||||||
 | 
					            if not isinstance(
 | 
				
			||||||
 | 
					                    full.full_chat.participants, types.ChatParticipants):
 | 
				
			||||||
 | 
					                # ChatParticipantsForbidden won't have ``.participants``
 | 
				
			||||||
 | 
					                self.total = 0
 | 
				
			||||||
 | 
					                raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.total = len(full.full_chat.participants.participants)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            users = {user.id: user for user in full.users}
 | 
				
			||||||
 | 
					            for participant in full.full_chat.participants.participants:
 | 
				
			||||||
 | 
					                user = users[participant.user_id]
 | 
				
			||||||
 | 
					                if not self.filter_entity(user):
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                user = users[participant.user_id]
 | 
				
			||||||
 | 
					                user.participant = participant
 | 
				
			||||||
 | 
					                self.buffer.append(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.total = 1
 | 
				
			||||||
 | 
					            if self.limit != 0:
 | 
				
			||||||
 | 
					                user = await self.client.get_entity(entity)
 | 
				
			||||||
 | 
					                if self.filter_entity(user):
 | 
				
			||||||
 | 
					                    user.participant = None
 | 
				
			||||||
 | 
					                    self.buffer.append(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        if not self.requests:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Only care about the limit for the first request
 | 
				
			||||||
 | 
					        # (small amount of people, won't be aggressive).
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        # Most people won't care about getting exactly 12,345
 | 
				
			||||||
 | 
					        # members so it doesn't really matter not to be 100%
 | 
				
			||||||
 | 
					        # precise with being out of the offset/limit here.
 | 
				
			||||||
 | 
					        self.requests[0].limit = min(self.limit - self.requests[0].offset, 200)
 | 
				
			||||||
 | 
					        if self.requests[0].offset > self.limit:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        results = await self.client(self.requests)
 | 
				
			||||||
 | 
					        for i in reversed(range(len(self.requests))):
 | 
				
			||||||
 | 
					            participants = results[i]
 | 
				
			||||||
 | 
					            if not participants.users:
 | 
				
			||||||
 | 
					                self.requests.pop(i)
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.requests[i].offset += len(participants.participants)
 | 
				
			||||||
 | 
					            users = {user.id: user for user in participants.users}
 | 
				
			||||||
 | 
					            for participant in participants.participants:
 | 
				
			||||||
 | 
					                user = users[participant.user_id]
 | 
				
			||||||
 | 
					                if not self.filter_entity(user) or user.id in self.seen:
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                self.seen.add(participant.user_id)
 | 
				
			||||||
 | 
					                user = users[participant.user_id]
 | 
				
			||||||
 | 
					                user.participant = participant
 | 
				
			||||||
 | 
					                self.buffer.append(user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _AdminLogIter(RequestIter):
 | 
				
			||||||
 | 
					    async def _init(
 | 
				
			||||||
 | 
					            self, entity, admins, search, min_id, max_id,
 | 
				
			||||||
 | 
					            join, leave, invite, restrict, unrestrict, ban, unban,
 | 
				
			||||||
 | 
					            promote, demote, info, settings, pinned, edit, delete
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        if any((join, leave, invite, restrict, unrestrict, ban, unban,
 | 
				
			||||||
 | 
					                promote, demote, info, settings, pinned, edit, delete)):
 | 
				
			||||||
 | 
					            events_filter = types.ChannelAdminLogEventsFilter(
 | 
				
			||||||
 | 
					                join=join, leave=leave, invite=invite, ban=restrict,
 | 
				
			||||||
 | 
					                unban=unrestrict, kick=ban, unkick=unban, promote=promote,
 | 
				
			||||||
 | 
					                demote=demote, info=info, settings=settings, pinned=pinned,
 | 
				
			||||||
 | 
					                edit=edit, delete=delete
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            events_filter = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.entity = await self.client.get_input_entity(entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        admin_list = []
 | 
				
			||||||
 | 
					        if admins:
 | 
				
			||||||
 | 
					            if not utils.is_list_like(admins):
 | 
				
			||||||
 | 
					                admins = (admins,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for admin in admins:
 | 
				
			||||||
 | 
					                admin_list.append(await self.client.get_input_entity(admin))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.request = functions.channels.GetAdminLogRequest(
 | 
				
			||||||
 | 
					            self.entity, q=search or '', min_id=min_id, max_id=max_id,
 | 
				
			||||||
 | 
					            limit=0, events_filter=events_filter, admins=admin_list or None
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        self.request.limit = min(self.left, 100)
 | 
				
			||||||
 | 
					        r = await self.client(self.request)
 | 
				
			||||||
 | 
					        entities = {utils.get_peer_id(x): x
 | 
				
			||||||
 | 
					                    for x in itertools.chain(r.users, r.chats)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.request.max_id = min((e.id for e in r.events), default=0)
 | 
				
			||||||
 | 
					        for ev in r.events:
 | 
				
			||||||
 | 
					            if isinstance(ev.action,
 | 
				
			||||||
 | 
					                          types.ChannelAdminLogEventActionEditMessage):
 | 
				
			||||||
 | 
					                ev.action.prev_message._finish_init(
 | 
				
			||||||
 | 
					                    self.client, entities, self.entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                ev.action.new_message._finish_init(
 | 
				
			||||||
 | 
					                    self.client, entities, self.entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            elif isinstance(ev.action,
 | 
				
			||||||
 | 
					                            types.ChannelAdminLogEventActionDeleteMessage):
 | 
				
			||||||
 | 
					                ev.action.message._finish_init(
 | 
				
			||||||
 | 
					                    self.client, entities, self.entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.buffer.append(custom.AdminLogEvent(ev, entities))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(r.events) < self.request.limit:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ChatMethods(UserMethods):
 | 
					class ChatMethods(UserMethods):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # region Public methods
 | 
					    # region Public methods
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @async_generator
 | 
					    def iter_participants(
 | 
				
			||||||
    async def iter_participants(
 | 
					 | 
				
			||||||
            self, entity, limit=None, *, search='',
 | 
					            self, entity, limit=None, *, search='',
 | 
				
			||||||
            filter=None, aggressive=False, _total=None):
 | 
					            filter=None, aggressive=False, _total=None):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -62,138 +235,23 @@ class ChatMethods(UserMethods):
 | 
				
			||||||
            matched :tl:`ChannelParticipant` type for channels/megagroups
 | 
					            matched :tl:`ChannelParticipant` type for channels/megagroups
 | 
				
			||||||
            or :tl:`ChatParticipants` for normal chats.
 | 
					            or :tl:`ChatParticipants` for normal chats.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if isinstance(filter, type):
 | 
					        return _ParticipantsIter(
 | 
				
			||||||
            if filter in (types.ChannelParticipantsBanned,
 | 
					            self,
 | 
				
			||||||
                          types.ChannelParticipantsKicked,
 | 
					            limit,
 | 
				
			||||||
                          types.ChannelParticipantsSearch,
 | 
					            entity=entity,
 | 
				
			||||||
                          types.ChannelParticipantsContacts):
 | 
					            filter=filter,
 | 
				
			||||||
                # These require a `q` parameter (support types for convenience)
 | 
					            search=search,
 | 
				
			||||||
                filter = filter('')
 | 
					            aggressive=aggressive
 | 
				
			||||||
            else:
 | 
					        )
 | 
				
			||||||
                filter = filter()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        entity = await self.get_input_entity(entity)
 | 
					 | 
				
			||||||
        if search and (filter
 | 
					 | 
				
			||||||
                       or not isinstance(entity, types.InputPeerChannel)):
 | 
					 | 
				
			||||||
            # We need to 'search' ourselves unless we have a PeerChannel
 | 
					 | 
				
			||||||
            search = search.lower()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            def filter_entity(ent):
 | 
					 | 
				
			||||||
                return search in utils.get_display_name(ent).lower() or\
 | 
					 | 
				
			||||||
                       search in (getattr(ent, 'username', '') or None).lower()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            def filter_entity(ent):
 | 
					 | 
				
			||||||
                return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        limit = float('inf') if limit is None else int(limit)
 | 
					 | 
				
			||||||
        if isinstance(entity, types.InputPeerChannel):
 | 
					 | 
				
			||||||
            if _total:
 | 
					 | 
				
			||||||
                _total[0] = (await self(
 | 
					 | 
				
			||||||
                    functions.channels.GetFullChannelRequest(entity)
 | 
					 | 
				
			||||||
                )).full_chat.participants_count
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if limit == 0:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            seen = set()
 | 
					 | 
				
			||||||
            if aggressive and not filter:
 | 
					 | 
				
			||||||
                requests = [functions.channels.GetParticipantsRequest(
 | 
					 | 
				
			||||||
                    channel=entity,
 | 
					 | 
				
			||||||
                    filter=types.ChannelParticipantsSearch(x),
 | 
					 | 
				
			||||||
                    offset=0,
 | 
					 | 
				
			||||||
                    limit=200,
 | 
					 | 
				
			||||||
                    hash=0
 | 
					 | 
				
			||||||
                ) for x in (search or map(chr, range(ord('a'), ord('z') + 1)))]
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                requests = [functions.channels.GetParticipantsRequest(
 | 
					 | 
				
			||||||
                    channel=entity,
 | 
					 | 
				
			||||||
                    filter=filter or types.ChannelParticipantsSearch(search),
 | 
					 | 
				
			||||||
                    offset=0,
 | 
					 | 
				
			||||||
                    limit=200,
 | 
					 | 
				
			||||||
                    hash=0
 | 
					 | 
				
			||||||
                )]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            while requests:
 | 
					 | 
				
			||||||
                # Only care about the limit for the first request
 | 
					 | 
				
			||||||
                # (small amount of people, won't be aggressive).
 | 
					 | 
				
			||||||
                #
 | 
					 | 
				
			||||||
                # Most people won't care about getting exactly 12,345
 | 
					 | 
				
			||||||
                # members so it doesn't really matter not to be 100%
 | 
					 | 
				
			||||||
                # precise with being out of the offset/limit here.
 | 
					 | 
				
			||||||
                requests[0].limit = min(limit - requests[0].offset, 200)
 | 
					 | 
				
			||||||
                if requests[0].offset > limit:
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                results = await self(requests)
 | 
					 | 
				
			||||||
                for i in reversed(range(len(requests))):
 | 
					 | 
				
			||||||
                    participants = results[i]
 | 
					 | 
				
			||||||
                    if not participants.users:
 | 
					 | 
				
			||||||
                        requests.pop(i)
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        requests[i].offset += len(participants.participants)
 | 
					 | 
				
			||||||
                        users = {user.id: user for user in participants.users}
 | 
					 | 
				
			||||||
                        for participant in participants.participants:
 | 
					 | 
				
			||||||
                            user = users[participant.user_id]
 | 
					 | 
				
			||||||
                            if not filter_entity(user) or user.id in seen:
 | 
					 | 
				
			||||||
                                continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            seen.add(participant.user_id)
 | 
					 | 
				
			||||||
                            user = users[participant.user_id]
 | 
					 | 
				
			||||||
                            user.participant = participant
 | 
					 | 
				
			||||||
                            await yield_(user)
 | 
					 | 
				
			||||||
                            if len(seen) >= limit:
 | 
					 | 
				
			||||||
                                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif isinstance(entity, types.InputPeerChat):
 | 
					 | 
				
			||||||
            full = await self(
 | 
					 | 
				
			||||||
                functions.messages.GetFullChatRequest(entity.chat_id))
 | 
					 | 
				
			||||||
            if not isinstance(
 | 
					 | 
				
			||||||
                    full.full_chat.participants, types.ChatParticipants):
 | 
					 | 
				
			||||||
                # ChatParticipantsForbidden won't have ``.participants``
 | 
					 | 
				
			||||||
                if _total:
 | 
					 | 
				
			||||||
                    _total[0] = 0
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if _total:
 | 
					 | 
				
			||||||
                _total[0] = len(full.full_chat.participants.participants)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            have = 0
 | 
					 | 
				
			||||||
            users = {user.id: user for user in full.users}
 | 
					 | 
				
			||||||
            for participant in full.full_chat.participants.participants:
 | 
					 | 
				
			||||||
                user = users[participant.user_id]
 | 
					 | 
				
			||||||
                if not filter_entity(user):
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
                have += 1
 | 
					 | 
				
			||||||
                if have > limit:
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    user = users[participant.user_id]
 | 
					 | 
				
			||||||
                    user.participant = participant
 | 
					 | 
				
			||||||
                    await yield_(user)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if _total:
 | 
					 | 
				
			||||||
                _total[0] = 1
 | 
					 | 
				
			||||||
            if limit != 0:
 | 
					 | 
				
			||||||
                user = await self.get_entity(entity)
 | 
					 | 
				
			||||||
                if filter_entity(user):
 | 
					 | 
				
			||||||
                    user.participant = None
 | 
					 | 
				
			||||||
                    await yield_(user)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_participants(self, *args, **kwargs):
 | 
					    async def get_participants(self, *args, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Same as `iter_participants`, but returns a
 | 
					        Same as `iter_participants`, but returns a
 | 
				
			||||||
        `TotalList <telethon.helpers.TotalList>` instead.
 | 
					        `TotalList <telethon.helpers.TotalList>` instead.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        total = [0]
 | 
					        return await self.iter_participants(*args, **kwargs).collect()
 | 
				
			||||||
        kwargs['_total'] = total
 | 
					 | 
				
			||||||
        participants = helpers.TotalList()
 | 
					 | 
				
			||||||
        async for x in self.iter_participants(*args, **kwargs):
 | 
					 | 
				
			||||||
            participants.append(x)
 | 
					 | 
				
			||||||
        participants.total = total[0]
 | 
					 | 
				
			||||||
        return participants
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @async_generator
 | 
					    def iter_admin_log(
 | 
				
			||||||
    async def iter_admin_log(
 | 
					 | 
				
			||||||
            self, entity, limit=None, *, max_id=0, min_id=0, search=None,
 | 
					            self, entity, limit=None, *, max_id=0, min_id=0, search=None,
 | 
				
			||||||
            admins=None, join=None, leave=None, invite=None, restrict=None,
 | 
					            admins=None, join=None, leave=None, invite=None, restrict=None,
 | 
				
			||||||
            unrestrict=None, ban=None, unban=None, promote=None, demote=None,
 | 
					            unrestrict=None, ban=None, unban=None, promote=None, demote=None,
 | 
				
			||||||
| 
						 | 
					@ -285,66 +343,34 @@ class ChatMethods(UserMethods):
 | 
				
			||||||
        Yields:
 | 
					        Yields:
 | 
				
			||||||
            Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
 | 
					            Instances of `telethon.tl.custom.adminlogevent.AdminLogEvent`.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if limit is None:
 | 
					        return _AdminLogIter(
 | 
				
			||||||
            limit = sys.maxsize
 | 
					            self,
 | 
				
			||||||
        elif limit <= 0:
 | 
					            limit,
 | 
				
			||||||
            return
 | 
					            entity=entity,
 | 
				
			||||||
 | 
					            admins=admins,
 | 
				
			||||||
        if any((join, leave, invite, restrict, unrestrict, ban, unban,
 | 
					            search=search,
 | 
				
			||||||
                promote, demote, info, settings, pinned, edit, delete)):
 | 
					            min_id=min_id,
 | 
				
			||||||
            events_filter = types.ChannelAdminLogEventsFilter(
 | 
					            max_id=max_id,
 | 
				
			||||||
                join=join, leave=leave, invite=invite, ban=restrict,
 | 
					            join=join,
 | 
				
			||||||
                unban=unrestrict, kick=ban, unkick=unban, promote=promote,
 | 
					            leave=leave,
 | 
				
			||||||
                demote=demote, info=info, settings=settings, pinned=pinned,
 | 
					            invite=invite,
 | 
				
			||||||
                edit=edit, delete=delete
 | 
					            restrict=restrict,
 | 
				
			||||||
 | 
					            unrestrict=unrestrict,
 | 
				
			||||||
 | 
					            ban=ban,
 | 
				
			||||||
 | 
					            unban=unban,
 | 
				
			||||||
 | 
					            promote=promote,
 | 
				
			||||||
 | 
					            demote=demote,
 | 
				
			||||||
 | 
					            info=info,
 | 
				
			||||||
 | 
					            settings=settings,
 | 
				
			||||||
 | 
					            pinned=pinned,
 | 
				
			||||||
 | 
					            edit=edit,
 | 
				
			||||||
 | 
					            delete=delete
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            events_filter = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        entity = await self.get_input_entity(entity)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        admin_list = []
 | 
					 | 
				
			||||||
        if admins:
 | 
					 | 
				
			||||||
            if not utils.is_list_like(admins):
 | 
					 | 
				
			||||||
                admins = (admins,)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for admin in admins:
 | 
					 | 
				
			||||||
                admin_list.append(await self.get_input_entity(admin))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request = functions.channels.GetAdminLogRequest(
 | 
					 | 
				
			||||||
            entity, q=search or '', min_id=min_id, max_id=max_id,
 | 
					 | 
				
			||||||
            limit=0, events_filter=events_filter, admins=admin_list or None
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        while limit > 0:
 | 
					 | 
				
			||||||
            request.limit = min(limit, 100)
 | 
					 | 
				
			||||||
            result = await self(request)
 | 
					 | 
				
			||||||
            limit -= len(result.events)
 | 
					 | 
				
			||||||
            entities = {utils.get_peer_id(x): x
 | 
					 | 
				
			||||||
                        for x in itertools.chain(result.users, result.chats)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            request.max_id = min((e.id for e in result.events), default=0)
 | 
					 | 
				
			||||||
            for ev in result.events:
 | 
					 | 
				
			||||||
                if isinstance(ev.action,
 | 
					 | 
				
			||||||
                              types.ChannelAdminLogEventActionEditMessage):
 | 
					 | 
				
			||||||
                    ev.action.prev_message._finish_init(self, entities, entity)
 | 
					 | 
				
			||||||
                    ev.action.new_message._finish_init(self, entities, entity)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                elif isinstance(ev.action,
 | 
					 | 
				
			||||||
                                types.ChannelAdminLogEventActionDeleteMessage):
 | 
					 | 
				
			||||||
                    ev.action.message._finish_init(self, entities, entity)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                await yield_(custom.AdminLogEvent(ev, entities))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if len(result.events) < request.limit:
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_admin_log(self, *args, **kwargs):
 | 
					    async def get_admin_log(self, *args, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Same as `iter_admin_log`, but returns a ``list`` instead.
 | 
					        Same as `iter_admin_log`, but returns a ``list`` instead.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        admin_log = []
 | 
					        return await self.iter_admin_log(*args, **kwargs).collect()
 | 
				
			||||||
        async for x in self.iter_admin_log(*args, **kwargs):
 | 
					 | 
				
			||||||
            admin_log.append(x)
 | 
					 | 
				
			||||||
        return admin_log
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # endregion
 | 
					    # endregion
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,18 +1,101 @@
 | 
				
			||||||
import itertools
 | 
					import itertools
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from async_generator import async_generator, yield_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .users import UserMethods
 | 
					from .users import UserMethods
 | 
				
			||||||
from .. import utils, helpers
 | 
					from .. import utils
 | 
				
			||||||
 | 
					from ..requestiter import RequestIter
 | 
				
			||||||
from ..tl import types, functions, custom
 | 
					from ..tl import types, functions, custom
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _DialogsIter(RequestIter):
 | 
				
			||||||
 | 
					    async def _init(
 | 
				
			||||||
 | 
					            self, offset_date, offset_id, offset_peer, ignore_migrated
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        self.request = functions.messages.GetDialogsRequest(
 | 
				
			||||||
 | 
					            offset_date=offset_date,
 | 
				
			||||||
 | 
					            offset_id=offset_id,
 | 
				
			||||||
 | 
					            offset_peer=offset_peer,
 | 
				
			||||||
 | 
					            limit=1,
 | 
				
			||||||
 | 
					            hash=0
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.limit == 0:
 | 
				
			||||||
 | 
					            # Special case, get a single dialog and determine count
 | 
				
			||||||
 | 
					            dialogs = await self.client(self.request)
 | 
				
			||||||
 | 
					            self.total = getattr(dialogs, 'count', len(dialogs.dialogs))
 | 
				
			||||||
 | 
					            raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.seen = set()
 | 
				
			||||||
 | 
					        self.offset_date = offset_date
 | 
				
			||||||
 | 
					        self.ignore_migrated = ignore_migrated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        self.request.limit = min(self.left, 100)
 | 
				
			||||||
 | 
					        r = await self.client(self.request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.total = getattr(r, 'count', len(r.dialogs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entities = {utils.get_peer_id(x): x
 | 
				
			||||||
 | 
					                    for x in itertools.chain(r.users, r.chats)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        messages = {}
 | 
				
			||||||
 | 
					        for m in r.messages:
 | 
				
			||||||
 | 
					            m._finish_init(self, entities, None)
 | 
				
			||||||
 | 
					            messages[m.id] = m
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for d in r.dialogs:
 | 
				
			||||||
 | 
					            # We check the offset date here because Telegram may ignore it
 | 
				
			||||||
 | 
					            if self.offset_date:
 | 
				
			||||||
 | 
					                date = getattr(messages.get(
 | 
				
			||||||
 | 
					                    d.top_message, None), 'date', None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if not date or date.timestamp() > self.offset_date.timestamp():
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            peer_id = utils.get_peer_id(d.peer)
 | 
				
			||||||
 | 
					            if peer_id not in self.seen:
 | 
				
			||||||
 | 
					                self.seen.add(peer_id)
 | 
				
			||||||
 | 
					                cd = custom.Dialog(self, d, entities, messages)
 | 
				
			||||||
 | 
					                if cd.dialog.pts:
 | 
				
			||||||
 | 
					                    self.client._channel_pts[cd.id] = cd.dialog.pts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if not self.ignore_migrated or getattr(
 | 
				
			||||||
 | 
					                        cd.entity, 'migrated_to', None) is None:
 | 
				
			||||||
 | 
					                    self.buffer.append(cd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(r.dialogs) < self.request.limit\
 | 
				
			||||||
 | 
					                or not isinstance(r, types.messages.DialogsSlice):
 | 
				
			||||||
 | 
					            # Less than we requested means we reached the end, or
 | 
				
			||||||
 | 
					            # we didn't get a DialogsSlice which means we got all.
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.request.offset_id == r.messages[-1].id:
 | 
				
			||||||
 | 
					            # In some very rare cases this will get stuck in an infinite
 | 
				
			||||||
 | 
					            # loop, where the offsets will get reused over and over. If
 | 
				
			||||||
 | 
					            # the new offset is the same as the one before, break already.
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.request.offset_id = r.messages[-1].id
 | 
				
			||||||
 | 
					        self.request.exclude_pinned = True
 | 
				
			||||||
 | 
					        self.request.offset_date = r.messages[-1].date
 | 
				
			||||||
 | 
					        self.request.offset_peer =\
 | 
				
			||||||
 | 
					            entities[utils.get_peer_id(r.dialogs[-1].peer)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _DraftsIter(RequestIter):
 | 
				
			||||||
 | 
					    async def _init(self, **kwargs):
 | 
				
			||||||
 | 
					        r = await self.client(functions.messages.GetAllDraftsRequest())
 | 
				
			||||||
 | 
					        self.buffer.extend(custom.Draft._from_update(self.client, u)
 | 
				
			||||||
 | 
					                           for u in r.updates)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DialogMethods(UserMethods):
 | 
					class DialogMethods(UserMethods):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # region Public methods
 | 
					    # region Public methods
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @async_generator
 | 
					    def iter_dialogs(
 | 
				
			||||||
    async def iter_dialogs(
 | 
					 | 
				
			||||||
            self, limit=None, *, offset_date=None, offset_id=0,
 | 
					            self, limit=None, *, offset_date=None, offset_id=0,
 | 
				
			||||||
            offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
 | 
					            offset_peer=types.InputPeerEmpty(), ignore_migrated=False,
 | 
				
			||||||
            _total=None):
 | 
					            _total=None):
 | 
				
			||||||
| 
						 | 
					@ -50,99 +133,23 @@ class DialogMethods(UserMethods):
 | 
				
			||||||
        Yields:
 | 
					        Yields:
 | 
				
			||||||
            Instances of `telethon.tl.custom.dialog.Dialog`.
 | 
					            Instances of `telethon.tl.custom.dialog.Dialog`.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        limit = float('inf') if limit is None else int(limit)
 | 
					        return _DialogsIter(
 | 
				
			||||||
        if limit == 0:
 | 
					            self,
 | 
				
			||||||
            if not _total:
 | 
					            limit,
 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
            # Special case, get a single dialog and determine count
 | 
					 | 
				
			||||||
            dialogs = await self(functions.messages.GetDialogsRequest(
 | 
					 | 
				
			||||||
            offset_date=offset_date,
 | 
					            offset_date=offset_date,
 | 
				
			||||||
            offset_id=offset_id,
 | 
					            offset_id=offset_id,
 | 
				
			||||||
            offset_peer=offset_peer,
 | 
					            offset_peer=offset_peer,
 | 
				
			||||||
                limit=1,
 | 
					            ignore_migrated=ignore_migrated
 | 
				
			||||||
                hash=0
 | 
					 | 
				
			||||||
            ))
 | 
					 | 
				
			||||||
            _total[0] = getattr(dialogs, 'count', len(dialogs.dialogs))
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        seen = set()
 | 
					 | 
				
			||||||
        req = functions.messages.GetDialogsRequest(
 | 
					 | 
				
			||||||
            offset_date=offset_date,
 | 
					 | 
				
			||||||
            offset_id=offset_id,
 | 
					 | 
				
			||||||
            offset_peer=offset_peer,
 | 
					 | 
				
			||||||
            limit=0,
 | 
					 | 
				
			||||||
            hash=0
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        while len(seen) < limit:
 | 
					 | 
				
			||||||
            req.limit = min(limit - len(seen), 100)
 | 
					 | 
				
			||||||
            r = await self(req)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if _total:
 | 
					 | 
				
			||||||
                _total[0] = getattr(r, 'count', len(r.dialogs))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            entities = {utils.get_peer_id(x): x
 | 
					 | 
				
			||||||
                        for x in itertools.chain(r.users, r.chats)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            messages = {}
 | 
					 | 
				
			||||||
            for m in r.messages:
 | 
					 | 
				
			||||||
                m._finish_init(self, entities, None)
 | 
					 | 
				
			||||||
                messages[m.id] = m
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Happens when there are pinned dialogs
 | 
					 | 
				
			||||||
            if len(r.dialogs) > limit:
 | 
					 | 
				
			||||||
                r.dialogs = r.dialogs[:limit]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for d in r.dialogs:
 | 
					 | 
				
			||||||
                if offset_date:
 | 
					 | 
				
			||||||
                    date = getattr(messages.get(
 | 
					 | 
				
			||||||
                        d.top_message, None), 'date', None)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if not date or date.timestamp() > offset_date.timestamp():
 | 
					 | 
				
			||||||
                        continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                peer_id = utils.get_peer_id(d.peer)
 | 
					 | 
				
			||||||
                if peer_id not in seen:
 | 
					 | 
				
			||||||
                    seen.add(peer_id)
 | 
					 | 
				
			||||||
                    cd = custom.Dialog(self, d, entities, messages)
 | 
					 | 
				
			||||||
                    if cd.dialog.pts:
 | 
					 | 
				
			||||||
                        self._channel_pts[cd.id] = cd.dialog.pts
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    if not ignore_migrated or getattr(
 | 
					 | 
				
			||||||
                            cd.entity, 'migrated_to', None) is None:
 | 
					 | 
				
			||||||
                        await yield_(cd)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if len(r.dialogs) < req.limit\
 | 
					 | 
				
			||||||
                    or not isinstance(r, types.messages.DialogsSlice):
 | 
					 | 
				
			||||||
                # Less than we requested means we reached the end, or
 | 
					 | 
				
			||||||
                # we didn't get a DialogsSlice which means we got all.
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            req.offset_date = r.messages[-1].date
 | 
					 | 
				
			||||||
            req.offset_peer = entities[utils.get_peer_id(r.dialogs[-1].peer)]
 | 
					 | 
				
			||||||
            if req.offset_id == r.messages[-1].id:
 | 
					 | 
				
			||||||
                # In some very rare cases this will get stuck in an infinite
 | 
					 | 
				
			||||||
                # loop, where the offsets will get reused over and over. If
 | 
					 | 
				
			||||||
                # the new offset is the same as the one before, break already.
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            req.offset_id = r.messages[-1].id
 | 
					 | 
				
			||||||
            req.exclude_pinned = True
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_dialogs(self, *args, **kwargs):
 | 
					    async def get_dialogs(self, *args, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Same as `iter_dialogs`, but returns a
 | 
					        Same as `iter_dialogs`, but returns a
 | 
				
			||||||
        `TotalList <telethon.helpers.TotalList>` instead.
 | 
					        `TotalList <telethon.helpers.TotalList>` instead.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        total = [0]
 | 
					        return await self.iter_dialogs(*args, **kwargs).collect()
 | 
				
			||||||
        kwargs['_total'] = total
 | 
					 | 
				
			||||||
        dialogs = helpers.TotalList()
 | 
					 | 
				
			||||||
        async for x in self.iter_dialogs(*args, **kwargs):
 | 
					 | 
				
			||||||
            dialogs.append(x)
 | 
					 | 
				
			||||||
        dialogs.total = total[0]
 | 
					 | 
				
			||||||
        return dialogs
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @async_generator
 | 
					    def iter_drafts(self):
 | 
				
			||||||
    async def iter_drafts(self):
 | 
					 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Iterator over all open draft messages.
 | 
					        Iterator over all open draft messages.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -151,18 +158,14 @@ class DialogMethods(UserMethods):
 | 
				
			||||||
        to change the message or `telethon.tl.custom.draft.Draft.delete`
 | 
					        to change the message or `telethon.tl.custom.draft.Draft.delete`
 | 
				
			||||||
        among other things.
 | 
					        among other things.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        r = await self(functions.messages.GetAllDraftsRequest())
 | 
					        # TODO Passing a limit here makes no sense
 | 
				
			||||||
        for update in r.updates:
 | 
					        return _DraftsIter(self, None)
 | 
				
			||||||
            await yield_(custom.Draft._from_update(self, update))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_drafts(self):
 | 
					    async def get_drafts(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Same as :meth:`iter_drafts`, but returns a list instead.
 | 
					        Same as :meth:`iter_drafts`, but returns a list instead.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        result = []
 | 
					        return await self.iter_drafts().collect()
 | 
				
			||||||
        async for x in self.iter_drafts():
 | 
					 | 
				
			||||||
            result.append(x)
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def conversation(
 | 
					    def conversation(
 | 
				
			||||||
            self, entity,
 | 
					            self, entity,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,14 +1,282 @@
 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import itertools
 | 
					import itertools
 | 
				
			||||||
import time
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from async_generator import async_generator, yield_
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .messageparse import MessageParseMethods
 | 
					from .messageparse import MessageParseMethods
 | 
				
			||||||
from .uploads import UploadMethods
 | 
					from .uploads import UploadMethods
 | 
				
			||||||
from .buttons import ButtonMethods
 | 
					from .buttons import ButtonMethods
 | 
				
			||||||
from .. import helpers, utils, errors
 | 
					from .. import helpers, utils, errors
 | 
				
			||||||
from ..tl import types, functions
 | 
					from ..tl import types, functions
 | 
				
			||||||
 | 
					from ..requestiter import RequestIter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO Maybe RequestIter could rather have the update offset here?
 | 
				
			||||||
 | 
					#      Maybe init should return the request to be used and it be
 | 
				
			||||||
 | 
					#      called automatically? And another method to just process it.
 | 
				
			||||||
 | 
					class _MessagesIter(RequestIter):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Common factor for all requests that need to iterate over messages.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    async def _init(
 | 
				
			||||||
 | 
					            self, entity, offset_id, min_id, max_id, from_user,
 | 
				
			||||||
 | 
					            batch_size, offset_date, add_offset, filter, search
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        # Note that entity being ``None`` will perform a global search.
 | 
				
			||||||
 | 
					        if entity:
 | 
				
			||||||
 | 
					            self.entity = await self.client.get_input_entity(entity)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.entity = None
 | 
				
			||||||
 | 
					            if self.reverse:
 | 
				
			||||||
 | 
					                raise ValueError('Cannot reverse global search')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Telegram doesn't like min_id/max_id. If these IDs are low enough
 | 
				
			||||||
 | 
					        # (starting from last_id - 100), the request will return nothing.
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        # We can emulate their behaviour locally by setting offset = max_id
 | 
				
			||||||
 | 
					        # and simply stopping once we hit a message with ID <= min_id.
 | 
				
			||||||
 | 
					        if self.reverse:
 | 
				
			||||||
 | 
					            offset_id = max(offset_id, min_id)
 | 
				
			||||||
 | 
					            if offset_id and max_id:
 | 
				
			||||||
 | 
					                if max_id - offset_id <= 1:
 | 
				
			||||||
 | 
					                    raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not max_id:
 | 
				
			||||||
 | 
					                max_id = float('inf')
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            offset_id = max(offset_id, max_id)
 | 
				
			||||||
 | 
					            if offset_id and min_id:
 | 
				
			||||||
 | 
					                if offset_id - min_id <= 1:
 | 
				
			||||||
 | 
					                    raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.reverse:
 | 
				
			||||||
 | 
					            if offset_id:
 | 
				
			||||||
 | 
					                offset_id += 1
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                offset_id = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if from_user:
 | 
				
			||||||
 | 
					            from_user = await self.client.get_input_entity(from_user)
 | 
				
			||||||
 | 
					            if not isinstance(from_user, (
 | 
				
			||||||
 | 
					                    types.InputPeerUser, types.InputPeerSelf)):
 | 
				
			||||||
 | 
					                from_user = None  # Ignore from_user unless it's a user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if from_user:
 | 
				
			||||||
 | 
					            self.from_id = await self.client.get_peer_id(from_user)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.from_id = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.entity:
 | 
				
			||||||
 | 
					            self.request = functions.messages.SearchGlobalRequest(
 | 
				
			||||||
 | 
					                q=search or '',
 | 
				
			||||||
 | 
					                offset_date=offset_date,
 | 
				
			||||||
 | 
					                offset_peer=types.InputPeerEmpty(),
 | 
				
			||||||
 | 
					                offset_id=offset_id,
 | 
				
			||||||
 | 
					                limit=1
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        elif search is not None or filter or from_user:
 | 
				
			||||||
 | 
					            if filter is None:
 | 
				
			||||||
 | 
					                filter = types.InputMessagesFilterEmpty()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Telegram completely ignores `from_id` in private chats
 | 
				
			||||||
 | 
					            if isinstance(
 | 
				
			||||||
 | 
					                    self.entity, (types.InputPeerUser, types.InputPeerSelf)):
 | 
				
			||||||
 | 
					                # Don't bother sending `from_user` (it's ignored anyway),
 | 
				
			||||||
 | 
					                # but keep `from_id` defined above to check it locally.
 | 
				
			||||||
 | 
					                from_user = None
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # Do send `from_user` to do the filtering server-side,
 | 
				
			||||||
 | 
					                # and set `from_id` to None to avoid checking it locally.
 | 
				
			||||||
 | 
					                self.from_id = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.request = functions.messages.SearchRequest(
 | 
				
			||||||
 | 
					                peer=self.entity,
 | 
				
			||||||
 | 
					                q=search or '',
 | 
				
			||||||
 | 
					                filter=filter() if isinstance(filter, type) else filter,
 | 
				
			||||||
 | 
					                min_date=None,
 | 
				
			||||||
 | 
					                max_date=offset_date,
 | 
				
			||||||
 | 
					                offset_id=offset_id,
 | 
				
			||||||
 | 
					                add_offset=add_offset,
 | 
				
			||||||
 | 
					                limit=0,  # Search actually returns 0 items if we ask it to
 | 
				
			||||||
 | 
					                max_id=0,
 | 
				
			||||||
 | 
					                min_id=0,
 | 
				
			||||||
 | 
					                hash=0,
 | 
				
			||||||
 | 
					                from_id=from_user
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.request = functions.messages.GetHistoryRequest(
 | 
				
			||||||
 | 
					                peer=self.entity,
 | 
				
			||||||
 | 
					                limit=1,
 | 
				
			||||||
 | 
					                offset_date=offset_date,
 | 
				
			||||||
 | 
					                offset_id=offset_id,
 | 
				
			||||||
 | 
					                min_id=0,
 | 
				
			||||||
 | 
					                max_id=0,
 | 
				
			||||||
 | 
					                add_offset=add_offset,
 | 
				
			||||||
 | 
					                hash=0
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.limit == 0:
 | 
				
			||||||
 | 
					            # No messages, but we still need to know the total message count
 | 
				
			||||||
 | 
					            result = await self.client(self.request)
 | 
				
			||||||
 | 
					            if isinstance(result, types.messages.MessagesNotModified):
 | 
				
			||||||
 | 
					                self.total = result.count
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.total = getattr(result, 'count', len(result.messages))
 | 
				
			||||||
 | 
					            raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.wait_time is None:
 | 
				
			||||||
 | 
					            self.wait_time = 1 if self.limit > 3000 else 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Telegram has a hard limit of 100.
 | 
				
			||||||
 | 
					        # We don't need to fetch 100 if the limit is less.
 | 
				
			||||||
 | 
					        self.batch_size = min(max(batch_size, 1), min(100, self.limit))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # When going in reverse we need an offset of `-limit`, but we
 | 
				
			||||||
 | 
					        # also want to respect what the user passed, so add them together.
 | 
				
			||||||
 | 
					        if self.reverse:
 | 
				
			||||||
 | 
					            self.request.add_offset -= self.batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.add_offset = add_offset
 | 
				
			||||||
 | 
					        self.max_id = max_id
 | 
				
			||||||
 | 
					        self.min_id = min_id
 | 
				
			||||||
 | 
					        self.last_id = 0 if self.reverse else float('inf')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        self.request.limit = min(self.left, self.batch_size)
 | 
				
			||||||
 | 
					        if self.reverse and self.request.limit != self.batch_size:
 | 
				
			||||||
 | 
					            # Remember that we need -limit when going in reverse
 | 
				
			||||||
 | 
					            self.request.add_offset = self.add_offset - self.request.limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        r = await self.client(self.request)
 | 
				
			||||||
 | 
					        self.total = getattr(r, 'count', len(r.messages))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entities = {utils.get_peer_id(x): x
 | 
				
			||||||
 | 
					                    for x in itertools.chain(r.users, r.chats)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        messages = reversed(r.messages) if self.reverse else r.messages
 | 
				
			||||||
 | 
					        for message in messages:
 | 
				
			||||||
 | 
					            if (isinstance(message, types.MessageEmpty)
 | 
				
			||||||
 | 
					                    or self.from_id and message.from_id != self.from_id):
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not self._message_in_range(message):
 | 
				
			||||||
 | 
					                return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # There has been reports that on bad connections this method
 | 
				
			||||||
 | 
					            # was returning duplicated IDs sometimes. Using ``last_id``
 | 
				
			||||||
 | 
					            # is an attempt to avoid these duplicates, since the message
 | 
				
			||||||
 | 
					            # IDs are returned in descending order (or asc if reverse).
 | 
				
			||||||
 | 
					            self.last_id = message.id
 | 
				
			||||||
 | 
					            message._finish_init(self.client, entities, self.entity)
 | 
				
			||||||
 | 
					            self.buffer.append(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(r.messages) < self.request.limit:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Get the last message that's not empty (in some rare cases
 | 
				
			||||||
 | 
					        # it can happen that the last message is :tl:`MessageEmpty`)
 | 
				
			||||||
 | 
					        if self.buffer:
 | 
				
			||||||
 | 
					            self._update_offset(self.buffer[-1])
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # There are some cases where all the messages we get start
 | 
				
			||||||
 | 
					            # being empty. This can happen on migrated mega-groups if
 | 
				
			||||||
 | 
					            # the history was cleared, and we're using search. Telegram
 | 
				
			||||||
 | 
					            # acts incredibly weird sometimes. Messages are returned but
 | 
				
			||||||
 | 
					            # only "empty", not their contents. If this is the case we
 | 
				
			||||||
 | 
					            # should just give up since there won't be any new Message.
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _message_in_range(self, message):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Determine whether the given message is in the range or
 | 
				
			||||||
 | 
					        it should be ignored (and avoid loading more chunks).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # No entity means message IDs between chats may vary
 | 
				
			||||||
 | 
					        if self.entity:
 | 
				
			||||||
 | 
					            if self.reverse:
 | 
				
			||||||
 | 
					                if message.id <= self.last_id or message.id >= self.max_id:
 | 
				
			||||||
 | 
					                    return False
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                if message.id >= self.last_id or message.id <= self.min_id:
 | 
				
			||||||
 | 
					                    return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _update_offset(self, last_message):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        After making the request, update its offset with the last message.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.request.offset_id = last_message.id
 | 
				
			||||||
 | 
					        if self.reverse:
 | 
				
			||||||
 | 
					            # We want to skip the one we already have
 | 
				
			||||||
 | 
					            self.request.offset_id += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(self.request, functions.messages.SearchRequest):
 | 
				
			||||||
 | 
					            # Unlike getHistory and searchGlobal that use *offset* date,
 | 
				
			||||||
 | 
					            # this is *max* date. This means that doing a search in reverse
 | 
				
			||||||
 | 
					            # will break it. Since it's not really needed once we're going
 | 
				
			||||||
 | 
					            # (only for the first request), it's safe to just clear it off.
 | 
				
			||||||
 | 
					            self.request.max_date = None
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # getHistory and searchGlobal call it offset_date
 | 
				
			||||||
 | 
					            self.request.offset_date = last_message.date
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(self.request, functions.messages.SearchGlobalRequest):
 | 
				
			||||||
 | 
					            self.request.offset_peer = last_message.input_chat
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _IDsIter(RequestIter):
 | 
				
			||||||
 | 
					    async def _init(self, entity, ids):
 | 
				
			||||||
 | 
					        # TODO We never actually split IDs in chunks, but maybe we should
 | 
				
			||||||
 | 
					        if not utils.is_list_like(ids):
 | 
				
			||||||
 | 
					            self.ids = [ids]
 | 
				
			||||||
 | 
					        elif not ids:
 | 
				
			||||||
 | 
					            raise StopAsyncIteration
 | 
				
			||||||
 | 
					        elif self.reverse:
 | 
				
			||||||
 | 
					            self.ids = list(reversed(ids))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.ids = ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if entity:
 | 
				
			||||||
 | 
					            entity = await self.client.get_input_entity(entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.total = len(ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        from_id = None  # By default, no need to validate from_id
 | 
				
			||||||
 | 
					        if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                r = await self.client(
 | 
				
			||||||
 | 
					                    functions.channels.GetMessagesRequest(entity, ids))
 | 
				
			||||||
 | 
					            except errors.MessageIdsEmptyError:
 | 
				
			||||||
 | 
					                # All IDs were invalid, use a dummy result
 | 
				
			||||||
 | 
					                r = types.messages.MessagesNotModified(len(ids))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            r = await self.client(functions.messages.GetMessagesRequest(ids))
 | 
				
			||||||
 | 
					            if entity:
 | 
				
			||||||
 | 
					                from_id = await self.client.get_peer_id(entity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(r, types.messages.MessagesNotModified):
 | 
				
			||||||
 | 
					            self.buffer.extend(None for _ in ids)
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entities = {utils.get_peer_id(x): x
 | 
				
			||||||
 | 
					                    for x in itertools.chain(r.users, r.chats)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Telegram seems to return the messages in the order in which
 | 
				
			||||||
 | 
					        # we asked them for, so we don't need to check it ourselves,
 | 
				
			||||||
 | 
					        # unless some messages were invalid in which case Telegram
 | 
				
			||||||
 | 
					        # may decide to not send them at all.
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        # The passed message IDs may not belong to the desired entity
 | 
				
			||||||
 | 
					        # since the user can enter arbitrary numbers which can belong to
 | 
				
			||||||
 | 
					        # arbitrary chats. Validate these unless ``from_id is None``.
 | 
				
			||||||
 | 
					        for message in r.messages:
 | 
				
			||||||
 | 
					            if isinstance(message, types.MessageEmpty) or (
 | 
				
			||||||
 | 
					                    from_id and message.chat_id != from_id):
 | 
				
			||||||
 | 
					                self.buffer.append(None)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                message._finish_init(self.client, entities, entity)
 | 
				
			||||||
 | 
					                self.buffer.append(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        return True  # no next chunk, all done in init
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
					class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
| 
						 | 
					@ -17,8 +285,7 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # region Message retrieval
 | 
					    # region Message retrieval
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @async_generator
 | 
					    def iter_messages(
 | 
				
			||||||
    async def iter_messages(
 | 
					 | 
				
			||||||
            self, entity, limit=None, *, offset_date=None, offset_id=0,
 | 
					            self, entity, limit=None, *, offset_date=None, offset_id=0,
 | 
				
			||||||
            max_id=0, min_id=0, add_offset=0, search=None, filter=None,
 | 
					            max_id=0, min_id=0, add_offset=0, search=None, filter=None,
 | 
				
			||||||
            from_user=None, batch_size=100, wait_time=None, ids=None,
 | 
					            from_user=None, batch_size=100, wait_time=None, ids=None,
 | 
				
			||||||
| 
						 | 
					@ -133,208 +400,26 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
            an higher limit, so you're free to set the ``batch_size`` that
 | 
					            an higher limit, so you're free to set the ``batch_size`` that
 | 
				
			||||||
            you think may be good.
 | 
					            you think may be good.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        # Note that entity being ``None`` is intended to get messages by
 | 
					 | 
				
			||||||
        # ID under no specific chat, and also to request a global search.
 | 
					 | 
				
			||||||
        if entity:
 | 
					 | 
				
			||||||
            entity = await self.get_input_entity(entity)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if ids:
 | 
					        if ids is not None:
 | 
				
			||||||
            if not utils.is_list_like(ids):
 | 
					            return _IDsIter(self, limit, entity=entity, ids=ids)
 | 
				
			||||||
                ids = (ids,)
 | 
					 | 
				
			||||||
            if reverse:
 | 
					 | 
				
			||||||
                ids = list(reversed(ids))
 | 
					 | 
				
			||||||
            async for x in self._iter_ids(entity, ids, total=_total):
 | 
					 | 
				
			||||||
                await yield_(x)
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Telegram doesn't like min_id/max_id. If these IDs are low enough
 | 
					        return _MessagesIter(
 | 
				
			||||||
        # (starting from last_id - 100), the request will return nothing.
 | 
					            client=self,
 | 
				
			||||||
        #
 | 
					            reverse=reverse,
 | 
				
			||||||
        # We can emulate their behaviour locally by setting offset = max_id
 | 
					            wait_time=wait_time,
 | 
				
			||||||
        # and simply stopping once we hit a message with ID <= min_id.
 | 
					            limit=limit,
 | 
				
			||||||
        if reverse:
 | 
					            entity=entity,
 | 
				
			||||||
            offset_id = max(offset_id, min_id)
 | 
					            offset_id=offset_id,
 | 
				
			||||||
            if offset_id and max_id:
 | 
					            min_id=min_id,
 | 
				
			||||||
                if max_id - offset_id <= 1:
 | 
					            max_id=max_id,
 | 
				
			||||||
                    return
 | 
					            from_user=from_user,
 | 
				
			||||||
 | 
					            batch_size=batch_size,
 | 
				
			||||||
            if not max_id:
 | 
					 | 
				
			||||||
                max_id = float('inf')
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            offset_id = max(offset_id, max_id)
 | 
					 | 
				
			||||||
            if offset_id and min_id:
 | 
					 | 
				
			||||||
                if offset_id - min_id <= 1:
 | 
					 | 
				
			||||||
                    return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if reverse:
 | 
					 | 
				
			||||||
            if offset_id:
 | 
					 | 
				
			||||||
                offset_id += 1
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                offset_id = 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if from_user:
 | 
					 | 
				
			||||||
            from_user = await self.get_input_entity(from_user)
 | 
					 | 
				
			||||||
            if not isinstance(from_user, (
 | 
					 | 
				
			||||||
                    types.InputPeerUser, types.InputPeerSelf)):
 | 
					 | 
				
			||||||
                from_user = None  # Ignore from_user unless it's a user
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from_id = (await self.get_peer_id(from_user)) if from_user else None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        limit = float('inf') if limit is None else int(limit)
 | 
					 | 
				
			||||||
        if not entity:
 | 
					 | 
				
			||||||
            if reverse:
 | 
					 | 
				
			||||||
                raise ValueError('Cannot reverse global search')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            reverse = None
 | 
					 | 
				
			||||||
            request = functions.messages.SearchGlobalRequest(
 | 
					 | 
				
			||||||
                q=search or '',
 | 
					 | 
				
			||||||
            offset_date=offset_date,
 | 
					            offset_date=offset_date,
 | 
				
			||||||
                offset_peer=types.InputPeerEmpty(),
 | 
					 | 
				
			||||||
                offset_id=offset_id,
 | 
					 | 
				
			||||||
                limit=1
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        elif search is not None or filter or from_user:
 | 
					 | 
				
			||||||
            if filter is None:
 | 
					 | 
				
			||||||
                filter = types.InputMessagesFilterEmpty()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Telegram completely ignores `from_id` in private chats
 | 
					 | 
				
			||||||
            if isinstance(entity, (types.InputPeerUser, types.InputPeerSelf)):
 | 
					 | 
				
			||||||
                # Don't bother sending `from_user` (it's ignored anyway),
 | 
					 | 
				
			||||||
                # but keep `from_id` defined above to check it locally.
 | 
					 | 
				
			||||||
                from_user = None
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                # Do send `from_user` to do the filtering server-side,
 | 
					 | 
				
			||||||
                # and set `from_id` to None to avoid checking it locally.
 | 
					 | 
				
			||||||
                from_id = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            request = functions.messages.SearchRequest(
 | 
					 | 
				
			||||||
                peer=entity,
 | 
					 | 
				
			||||||
                q=search or '',
 | 
					 | 
				
			||||||
                filter=filter() if isinstance(filter, type) else filter,
 | 
					 | 
				
			||||||
                min_date=None,
 | 
					 | 
				
			||||||
                max_date=offset_date,
 | 
					 | 
				
			||||||
                offset_id=offset_id,
 | 
					 | 
				
			||||||
            add_offset=add_offset,
 | 
					            add_offset=add_offset,
 | 
				
			||||||
                limit=0,  # Search actually returns 0 items if we ask it to
 | 
					            filter=filter,
 | 
				
			||||||
                max_id=0,
 | 
					            search=search
 | 
				
			||||||
                min_id=0,
 | 
					 | 
				
			||||||
                hash=0,
 | 
					 | 
				
			||||||
                from_id=from_user
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            request = functions.messages.GetHistoryRequest(
 | 
					 | 
				
			||||||
                peer=entity,
 | 
					 | 
				
			||||||
                limit=1,
 | 
					 | 
				
			||||||
                offset_date=offset_date,
 | 
					 | 
				
			||||||
                offset_id=offset_id,
 | 
					 | 
				
			||||||
                min_id=0,
 | 
					 | 
				
			||||||
                max_id=0,
 | 
					 | 
				
			||||||
                add_offset=add_offset,
 | 
					 | 
				
			||||||
                hash=0
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if limit == 0:
 | 
					 | 
				
			||||||
            if not _total:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
            # No messages, but we still need to know the total message count
 | 
					 | 
				
			||||||
            result = await self(request)
 | 
					 | 
				
			||||||
            if isinstance(result, types.messages.MessagesNotModified):
 | 
					 | 
				
			||||||
                _total[0] = result.count
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                _total[0] = getattr(result, 'count', len(result.messages))
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if wait_time is None:
 | 
					 | 
				
			||||||
            wait_time = 1 if limit > 3000 else 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        have = 0
 | 
					 | 
				
			||||||
        last_id = 0 if reverse else float('inf')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Telegram has a hard limit of 100.
 | 
					 | 
				
			||||||
        # We don't need to fetch 100 if the limit is less.
 | 
					 | 
				
			||||||
        batch_size = min(max(batch_size, 1), min(100, limit))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # When going in reverse we need an offset of `-limit`, but we
 | 
					 | 
				
			||||||
        # also want to respect what the user passed, so add them together.
 | 
					 | 
				
			||||||
        if reverse:
 | 
					 | 
				
			||||||
            request.add_offset -= batch_size
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        while have < limit:
 | 
					 | 
				
			||||||
            start = time.time()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            request.limit = min(limit - have, batch_size)
 | 
					 | 
				
			||||||
            if reverse and request.limit != batch_size:
 | 
					 | 
				
			||||||
                # Remember that we need -limit when going in reverse
 | 
					 | 
				
			||||||
                request.add_offset = add_offset - request.limit
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            r = await self(request)
 | 
					 | 
				
			||||||
            if _total:
 | 
					 | 
				
			||||||
                _total[0] = getattr(r, 'count', len(r.messages))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            entities = {utils.get_peer_id(x): x
 | 
					 | 
				
			||||||
                        for x in itertools.chain(r.users, r.chats)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            messages = reversed(r.messages) if reverse else r.messages
 | 
					 | 
				
			||||||
            for message in messages:
 | 
					 | 
				
			||||||
                if (isinstance(message, types.MessageEmpty)
 | 
					 | 
				
			||||||
                        or from_id and message.from_id != from_id):
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if reverse is None:
 | 
					 | 
				
			||||||
                    pass
 | 
					 | 
				
			||||||
                elif reverse:
 | 
					 | 
				
			||||||
                    if message.id <= last_id or message.id >= max_id:
 | 
					 | 
				
			||||||
                        return
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    if message.id >= last_id or message.id <= min_id:
 | 
					 | 
				
			||||||
                        return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # There has been reports that on bad connections this method
 | 
					 | 
				
			||||||
                # was returning duplicated IDs sometimes. Using ``last_id``
 | 
					 | 
				
			||||||
                # is an attempt to avoid these duplicates, since the message
 | 
					 | 
				
			||||||
                # IDs are returned in descending order (or asc if reverse).
 | 
					 | 
				
			||||||
                last_id = message.id
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                message._finish_init(self, entities, entity)
 | 
					 | 
				
			||||||
                await yield_(message)
 | 
					 | 
				
			||||||
                have += 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if len(r.messages) < request.limit:
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Find the first message that's not empty (in some rare cases
 | 
					 | 
				
			||||||
            # it can happen that the last message is :tl:`MessageEmpty`)
 | 
					 | 
				
			||||||
            last_message = None
 | 
					 | 
				
			||||||
            messages = r.messages if reverse else reversed(r.messages)
 | 
					 | 
				
			||||||
            for m in messages:
 | 
					 | 
				
			||||||
                if not isinstance(m, types.MessageEmpty):
 | 
					 | 
				
			||||||
                    last_message = m
 | 
					 | 
				
			||||||
                    break
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if last_message is None:
 | 
					 | 
				
			||||||
                # There are some cases where all the messages we get start
 | 
					 | 
				
			||||||
                # being empty. This can happen on migrated mega-groups if
 | 
					 | 
				
			||||||
                # the history was cleared, and we're using search. Telegram
 | 
					 | 
				
			||||||
                # acts incredibly weird sometimes. Messages are returned but
 | 
					 | 
				
			||||||
                # only "empty", not their contents. If this is the case we
 | 
					 | 
				
			||||||
                # should just give up since there won't be any new Message.
 | 
					 | 
				
			||||||
                break
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                request.offset_id = last_message.id
 | 
					 | 
				
			||||||
                if isinstance(request, functions.messages.SearchRequest):
 | 
					 | 
				
			||||||
                    request.max_date = last_message.date
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    # getHistory and searchGlobal call it offset_date
 | 
					 | 
				
			||||||
                    request.offset_date = last_message.date
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                if isinstance(request, functions.messages.SearchGlobalRequest):
 | 
					 | 
				
			||||||
                    request.offset_peer = last_message.input_chat
 | 
					 | 
				
			||||||
                elif reverse:
 | 
					 | 
				
			||||||
                    # We want to skip the one we already have
 | 
					 | 
				
			||||||
                    request.offset_id += 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            await asyncio.sleep(
 | 
					 | 
				
			||||||
                max(wait_time - (time.time() - start), 0), loop=self._loop)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_messages(self, *args, **kwargs):
 | 
					    async def get_messages(self, *args, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -353,23 +438,23 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
        a single `Message <telethon.tl.custom.message.Message>` will be
 | 
					        a single `Message <telethon.tl.custom.message.Message>` will be
 | 
				
			||||||
        returned for convenience instead of a list.
 | 
					        returned for convenience instead of a list.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        total = [0]
 | 
					 | 
				
			||||||
        kwargs['_total'] = total
 | 
					 | 
				
			||||||
        if len(args) == 1 and 'limit' not in kwargs:
 | 
					        if len(args) == 1 and 'limit' not in kwargs:
 | 
				
			||||||
            if 'min_id' in kwargs and 'max_id' in kwargs:
 | 
					            if 'min_id' in kwargs and 'max_id' in kwargs:
 | 
				
			||||||
                kwargs['limit'] = None
 | 
					                kwargs['limit'] = None
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                kwargs['limit'] = 1
 | 
					                kwargs['limit'] = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        msgs = helpers.TotalList()
 | 
					        it = self.iter_messages(*args, **kwargs)
 | 
				
			||||||
        async for x in self.iter_messages(*args, **kwargs):
 | 
					 | 
				
			||||||
            msgs.append(x)
 | 
					 | 
				
			||||||
        msgs.total = total[0]
 | 
					 | 
				
			||||||
        if 'ids' in kwargs and not utils.is_list_like(kwargs['ids']):
 | 
					 | 
				
			||||||
            # Check for empty list to handle InputMessageReplyTo
 | 
					 | 
				
			||||||
            return msgs[0] if msgs else None
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return msgs
 | 
					        ids = kwargs.get('ids')
 | 
				
			||||||
 | 
					        if ids and not utils.is_list_like(ids):
 | 
				
			||||||
 | 
					            async for message in it:
 | 
				
			||||||
 | 
					                return message
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # Iterator exhausted = empty, to handle InputMessageReplyTo
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return await it.collect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # endregion
 | 
					    # endregion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -799,52 +884,3 @@ class MessageMethods(UploadMethods, ButtonMethods, MessageParseMethods):
 | 
				
			||||||
    # endregion
 | 
					    # endregion
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # endregion
 | 
					    # endregion
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # region Private methods
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @async_generator
 | 
					 | 
				
			||||||
    async def _iter_ids(self, entity, ids, total):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Special case for `iter_messages` when it should only fetch some IDs.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if total:
 | 
					 | 
				
			||||||
            total[0] = len(ids)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from_id = None  # By default, no need to validate from_id
 | 
					 | 
				
			||||||
        if isinstance(entity, (types.InputChannel, types.InputPeerChannel)):
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                r = await self(
 | 
					 | 
				
			||||||
                    functions.channels.GetMessagesRequest(entity, ids))
 | 
					 | 
				
			||||||
            except errors.MessageIdsEmptyError:
 | 
					 | 
				
			||||||
                # All IDs were invalid, use a dummy result
 | 
					 | 
				
			||||||
                r = types.messages.MessagesNotModified(len(ids))
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            r = await self(functions.messages.GetMessagesRequest(ids))
 | 
					 | 
				
			||||||
            if entity:
 | 
					 | 
				
			||||||
                from_id = utils.get_peer_id(entity)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if isinstance(r, types.messages.MessagesNotModified):
 | 
					 | 
				
			||||||
            for _ in ids:
 | 
					 | 
				
			||||||
                await yield_(None)
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        entities = {utils.get_peer_id(x): x
 | 
					 | 
				
			||||||
                    for x in itertools.chain(r.users, r.chats)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Telegram seems to return the messages in the order in which
 | 
					 | 
				
			||||||
        # we asked them for, so we don't need to check it ourselves,
 | 
					 | 
				
			||||||
        # unless some messages were invalid in which case Telegram
 | 
					 | 
				
			||||||
        # may decide to not send them at all.
 | 
					 | 
				
			||||||
        #
 | 
					 | 
				
			||||||
        # The passed message IDs may not belong to the desired entity
 | 
					 | 
				
			||||||
        # since the user can enter arbitrary numbers which can belong to
 | 
					 | 
				
			||||||
        # arbitrary chats. Validate these unless ``from_id is None``.
 | 
					 | 
				
			||||||
        for message in r.messages:
 | 
					 | 
				
			||||||
            if isinstance(message, types.MessageEmpty) or (
 | 
					 | 
				
			||||||
                    from_id and message.chat_id != from_id):
 | 
					 | 
				
			||||||
                await yield_(None)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                message._finish_init(self, entities, entity)
 | 
					 | 
				
			||||||
                await yield_(message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # endregion
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										130
									
								
								telethon/requestiter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								telethon/requestiter.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,130 @@
 | 
				
			||||||
 | 
					import abc
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from . import helpers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO There are two types of iterators for requests.
 | 
				
			||||||
 | 
					#      One has a limit of items to retrieve, and the
 | 
				
			||||||
 | 
					#      other has a list that must be called in chunks.
 | 
				
			||||||
 | 
					#      Make classes for both here so it's easy to use.
 | 
				
			||||||
 | 
					class RequestIter(abc.ABC):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Helper class to deal with requests that need offsets to iterate.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    It has some facilities, such as automatically sleeping a desired
 | 
				
			||||||
 | 
					    amount of time between requests if needed (but not more).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Can be used synchronously if the event loop is not running and
 | 
				
			||||||
 | 
					    as an asynchronous iterator otherwise.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    `limit` is the total amount of items that the iterator should return.
 | 
				
			||||||
 | 
					    This is handled on this base class, and will be always ``>= 0``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    `left` will be reset every time the iterator is used and will indicate
 | 
				
			||||||
 | 
					    the amount of items that should be emitted left, so that subclasses can
 | 
				
			||||||
 | 
					    be more efficient and fetch only as many items as they need.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Iterators may be used with ``reversed``, and their `reverse` flag will
 | 
				
			||||||
 | 
					    be set to ``True`` if that's the case. Note that if this flag is set,
 | 
				
			||||||
 | 
					    `buffer` should be filled in reverse too.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def __init__(self, client, limit, *, reverse=False, wait_time=None, **kwargs):
 | 
				
			||||||
 | 
					        self.client = client
 | 
				
			||||||
 | 
					        self.reverse = reverse
 | 
				
			||||||
 | 
					        self.wait_time = wait_time
 | 
				
			||||||
 | 
					        self.kwargs = kwargs
 | 
				
			||||||
 | 
					        self.limit = max(float('inf') if limit is None else limit, 0)
 | 
				
			||||||
 | 
					        self.left = None
 | 
				
			||||||
 | 
					        self.buffer = None
 | 
				
			||||||
 | 
					        self.index = None
 | 
				
			||||||
 | 
					        self.total = None
 | 
				
			||||||
 | 
					        self.last_load = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _init(self, **kwargs):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Called when asynchronous initialization is necessary. All keyword
 | 
				
			||||||
 | 
					        arguments passed to `__init__` will be forwarded here, and it's
 | 
				
			||||||
 | 
					        preferable to use named arguments in the subclasses without defaults
 | 
				
			||||||
 | 
					        to avoid forgetting or misspelling any of them.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This method may ``raise StopAsyncIteration`` if it cannot continue.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This method may actually fill the initial buffer if it needs to.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def __anext__(self):
 | 
				
			||||||
 | 
					        if self.buffer is None:
 | 
				
			||||||
 | 
					            self.buffer = []
 | 
				
			||||||
 | 
					            await self._init(**self.kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.left <= 0:  # <= 0 because subclasses may change it
 | 
				
			||||||
 | 
					            raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.index == len(self.buffer):
 | 
				
			||||||
 | 
					            # asyncio will handle times <= 0 to sleep 0 seconds
 | 
				
			||||||
 | 
					            if self.wait_time:
 | 
				
			||||||
 | 
					                await asyncio.sleep(
 | 
				
			||||||
 | 
					                    self.wait_time - (time.time() - self.last_load),
 | 
				
			||||||
 | 
					                    loop=self.client.loop
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                self.last_load = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.index = 0
 | 
				
			||||||
 | 
					            self.buffer = []
 | 
				
			||||||
 | 
					            if await self._load_next_chunk():
 | 
				
			||||||
 | 
					                self.left = len(self.buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.buffer:
 | 
				
			||||||
 | 
					            raise StopAsyncIteration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        result = self.buffer[self.index]
 | 
				
			||||||
 | 
					        self.left -= 1
 | 
				
			||||||
 | 
					        self.index += 1
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __aiter__(self):
 | 
				
			||||||
 | 
					        self.buffer = None
 | 
				
			||||||
 | 
					        self.index = 0
 | 
				
			||||||
 | 
					        self.last_load = 0
 | 
				
			||||||
 | 
					        self.left = self.limit
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __iter__(self):
 | 
				
			||||||
 | 
					        if self.client.loop.is_running():
 | 
				
			||||||
 | 
					            raise RuntimeError(
 | 
				
			||||||
 | 
					                'You must use "async for" if the event loop '
 | 
				
			||||||
 | 
					                'is running (i.e. you are inside an "async def")'
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.__aiter__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def collect(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Create a `self` iterator and collect it into a `TotalList`
 | 
				
			||||||
 | 
					        (a normal list with a `.total` attribute).
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        result = helpers.TotalList()
 | 
				
			||||||
 | 
					        async for message in self:
 | 
				
			||||||
 | 
					            result.append(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        result.total = self.total
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    async def _load_next_chunk(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Called when the next chunk is necessary.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It should extend the `buffer` with new items.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It should return ``True`` if it's the last chunk,
 | 
				
			||||||
 | 
					        after which moment the method won't be called again
 | 
				
			||||||
 | 
					        during the same iteration.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __reversed__(self):
 | 
				
			||||||
 | 
					        self.reverse = not self.reverse
 | 
				
			||||||
 | 
					        return self  # __aiter__ will be called after, too
 | 
				
			||||||
| 
						 | 
					@ -14,8 +14,6 @@ import asyncio
 | 
				
			||||||
import functools
 | 
					import functools
 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from async_generator import isasyncgenfunction
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .client.telegramclient import TelegramClient
 | 
					from .client.telegramclient import TelegramClient
 | 
				
			||||||
from .tl.custom import (
 | 
					from .tl.custom import (
 | 
				
			||||||
    Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
 | 
					    Draft, Dialog, MessageButton, Forward, Message, InlineResult, Conversation
 | 
				
			||||||
| 
						 | 
					@ -24,22 +22,7 @@ from .tl.custom.chatgetter import ChatGetter
 | 
				
			||||||
from .tl.custom.sendergetter import SenderGetter
 | 
					from .tl.custom.sendergetter import SenderGetter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _SyncGen:
 | 
					def _syncify_wrap(t, method_name):
 | 
				
			||||||
    def __init__(self, gen):
 | 
					 | 
				
			||||||
        self.gen = gen
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __iter__(self):
 | 
					 | 
				
			||||||
        return self
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __next__(self):
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            return asyncio.get_event_loop() \
 | 
					 | 
				
			||||||
                .run_until_complete(self.gen.__anext__())
 | 
					 | 
				
			||||||
        except StopAsyncIteration:
 | 
					 | 
				
			||||||
            raise StopIteration from None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def _syncify_wrap(t, method_name, gen):
 | 
					 | 
				
			||||||
    method = getattr(t, method_name)
 | 
					    method = getattr(t, method_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @functools.wraps(method)
 | 
					    @functools.wraps(method)
 | 
				
			||||||
| 
						 | 
					@ -48,8 +31,6 @@ def _syncify_wrap(t, method_name, gen):
 | 
				
			||||||
        loop = asyncio.get_event_loop()
 | 
					        loop = asyncio.get_event_loop()
 | 
				
			||||||
        if loop.is_running():
 | 
					        if loop.is_running():
 | 
				
			||||||
            return coro
 | 
					            return coro
 | 
				
			||||||
        elif gen:
 | 
					 | 
				
			||||||
            return _SyncGen(coro)
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return loop.run_until_complete(coro)
 | 
					            return loop.run_until_complete(coro)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -64,13 +45,14 @@ def syncify(*types):
 | 
				
			||||||
    into synchronous, which return either the coroutine or the result
 | 
					    into synchronous, which return either the coroutine or the result
 | 
				
			||||||
    based on whether ``asyncio's`` event loop is running.
 | 
					    based on whether ``asyncio's`` event loop is running.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    # Our asynchronous generators all are `RequestIter`, which already
 | 
				
			||||||
 | 
					    # provide a synchronous iterator variant, so we don't need to worry
 | 
				
			||||||
 | 
					    # about asyncgenfunction's here.
 | 
				
			||||||
    for t in types:
 | 
					    for t in types:
 | 
				
			||||||
        for name in dir(t):
 | 
					        for name in dir(t):
 | 
				
			||||||
            if not name.startswith('_') or name == '__call__':
 | 
					            if not name.startswith('_') or name == '__call__':
 | 
				
			||||||
                if inspect.iscoroutinefunction(getattr(t, name)):
 | 
					                if inspect.iscoroutinefunction(getattr(t, name)):
 | 
				
			||||||
                    _syncify_wrap(t, name, gen=False)
 | 
					                    _syncify_wrap(t, name)
 | 
				
			||||||
                elif isasyncgenfunction(getattr(t, name)):
 | 
					 | 
				
			||||||
                    _syncify_wrap(t, name, gen=True)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
syncify(TelegramClient, Draft, Dialog, MessageButton,
 | 
					syncify(TelegramClient, Draft, Dialog, MessageButton,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user