diff --git a/telethon/network/mtprotosender.py b/telethon/network/mtprotosender.py index d15216c1..23bf3615 100644 --- a/telethon/network/mtprotosender.py +++ b/telethon/network/mtprotosender.py @@ -56,7 +56,7 @@ class MTProtoSender: self._recv_loop_handle = None # Sending something shouldn't block - self._send_queue = asyncio.Queue() + self._send_queue = _ContainerQueue() # Telegram responds to messages out of order. Keep # {id: Message} to set their Future result upon arrival. @@ -168,16 +168,9 @@ class MTProtoSender: Besides `connect`, only this method ever sends data. """ while self._user_connected: - messages = [await self._send_queue.get()] - while not self._send_queue.empty(): - messages.append(self._send_queue.get_nowait()) - - # TODO if _send_queue has a container and we wrap it inside - # another then that will not work. - if len(messages) == 1: - message = messages[0] - else: - message = TLMessage(self.session, MessageContainer(messages)) + message = await self._send_queue.get() + if isinstance(message, list): + message = TLMessage(self.session, MessageContainer(message)) self._pending_messages[message.msg_id] = message self._pending_containers.append(message) @@ -432,3 +425,31 @@ class MTProtoSender: msg = self._pending_messages.pop(msg_id, None) if msg: msg.future.set_result(salts) + + +class _ContainerQueue(asyncio.Queue): + """ + An asyncio queue that's aware of `MessageContainer` instances. + + The `get` method returns either a single `TLMessage` or a list + of them that should be turned into a new `MessageContainer`. + + Instances of this class can be replaced with the simpler + ``asyncio.Queue`` when needed for testing purposes, and + a list won't be returned in said case. + """ + async def get(self): + result = await super().get() + if self.empty() or isinstance(result.request, MessageContainer): + return result + + result = [result] + while not self.empty(): + item = self.get_nowait() + if isinstance(item.request, MessageContainer): + await self.put(item) + break + else: + result.append(item) + + return result