Add a method to cancel_all conversations (#1183)

This commit is contained in:
Lonami Exo 2019-06-03 19:41:22 +02:00
parent 690a40be77
commit 4c3e467d25
3 changed files with 24 additions and 13 deletions

View File

@ -1,5 +1,6 @@
import abc import abc
import asyncio import asyncio
import collections
import logging import logging
import platform import platform
import time import time
@ -322,8 +323,9 @@ class TelegramBaseClient(abc.ABC):
# Some further state for subclasses # Some further state for subclasses
self._event_builders = [] self._event_builders = []
self._conversations = {}
self._ids_in_conversations = {} # chat_id: count # {chat_id: {Conversation}}
self._conversations = collections.defaultdict(set)
# Default parse mode # Default parse mode
self._parse_mode = markdown self._parse_mode = markdown

View File

@ -382,8 +382,8 @@ class UpdateMethods(UserMethods):
await self._get_difference(update, channel_id, pts_date) await self._get_difference(update, channel_id, pts_date)
built = EventBuilderDict(self, update) built = EventBuilderDict(self, update)
if self._conversations: for conv_set in self._conversations.values():
for conv in self._conversations.values(): for conv in conv_set:
ev = built[events.NewMessage] ev = built[events.NewMessage]
if ev: if ev:
conv._on_new_message(ev) conv._on_new_message(ev)

View File

@ -394,12 +394,11 @@ class Conversation(ChatGetter):
# Make sure we're the only conversation in this chat if it's exclusive # Make sure we're the only conversation in this chat if it's exclusive
chat_id = utils.get_peer_id(self._chat_peer) chat_id = utils.get_peer_id(self._chat_peer)
count = self._client._ids_in_conversations.get(chat_id, 0) conv_set = self._client._conversations[chat_id]
if self._exclusive and count: if self._exclusive and conv_set:
raise errors.AlreadyInConversationError() raise errors.AlreadyInConversationError()
self._client._ids_in_conversations[chat_id] = count + 1 conv_set.add(self)
self._client._conversations[self._id] = self
self._last_outgoing = 0 self._last_outgoing = 0
self._last_incoming = 0 self._last_incoming = 0
@ -426,14 +425,24 @@ class Conversation(ChatGetter):
""" """
self._cancel_all() self._cancel_all()
async def cancel_all(self):
"""
Calls `cancel` on *all* conversations in this chat.
Note that you should ``await`` this method, since it's meant to be
used outside of a context manager, and it needs to resolve the chat.
"""
chat_id = await self._client.get_peer_id(self._input_chat)
for conv in self._client._conversations[chat_id]:
conv.cancel()
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
chat_id = utils.get_peer_id(self._chat_peer) chat_id = utils.get_peer_id(self._chat_peer)
if self._client._ids_in_conversations[chat_id] == 1: conv_set = self._client._conversations[chat_id]
del self._client._ids_in_conversations[chat_id] conv_set.discard(self)
else: if not conv_set:
self._client._ids_in_conversations[chat_id] -= 1 del self._client._conversations[chat_id]
del self._client._conversations[self._id]
self._cancel_all() self._cancel_all()
__enter__ = helpers._sync_enter __enter__ = helpers._sync_enter