Reduce __enter__/__exit__ boilerplate for sync ctx managers

This commit is contained in:
Lonami Exo 2019-04-13 10:53:33 +02:00
parent badefcec48
commit 9090ede5db
5 changed files with 33 additions and 51 deletions

View File

@ -2,7 +2,7 @@ import functools
import inspect
from .users import UserMethods, _NOT_A_REQUEST
from .. import utils
from .. import helpers, utils
from ..tl import functions, TLRequest
@ -31,15 +31,6 @@ class _TakeoutClient:
def success(self, value):
self.__success = value
def __enter__(self):
if self.__client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.__client.loop.run_until_complete(self.__aenter__())
async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start.
client = self.__client
@ -50,9 +41,6 @@ class _TakeoutClient:
"takeout for the current session still not been finished yet.")
return self
def __exit__(self, *args):
return self.__client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, exc_type, exc_value, traceback):
if self.__success is None and self.__finalize:
self.__success = exc_type is None
@ -64,6 +52,9 @@ class _TakeoutClient:
raise ValueError("Failed to finish the takeout.")
self.session.takeout_id = None
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
async def __call__(self, request, ordered=False):
takeout_id = self.__client.session.takeout_id
if takeout_id is None:

View File

@ -536,23 +536,13 @@ class AuthMethods(MessageParseMethods, UserMethods):
# region with blocks
def __enter__(self):
if self._loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.start()
async def __aenter__(self):
return await self.start()
def __exit__(self, *args):
# No loop.run_until_complete; it's already syncified
self.disconnect()
async def __aexit__(self, *args):
await self.disconnect()
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
# endregion

View File

@ -3,7 +3,7 @@ import itertools
import string
from .users import UserMethods
from .. import utils
from .. import helpers, utils
from ..requestiter import RequestIter
from ..tl import types, functions, custom
@ -69,17 +69,8 @@ class _ChatAction:
self._task = None
def __enter__(self):
if self._client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
def __exit__(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
async def _update(self):
try:

View File

@ -102,6 +102,25 @@ async def _cancel(log, **tasks):
except Exception:
log.exception('Unhandled exception from %s after cancel', name)
def _sync_enter(self):
"""
Helps to cut boilerplate on async context
managers that offer synchronous variants.
"""
if self._client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
def _sync_exit(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
# endregion
# region Cryptographic related utils

View File

@ -3,7 +3,7 @@ import itertools
import time
from .chatgetter import ChatGetter
from ... import utils, errors
from ... import helpers, utils, errors
# Sometimes the edits arrive very fast (within the same second).
@ -403,15 +403,6 @@ class Conversation(ChatGetter):
else:
fut.cancel()
def __enter__(self):
if self._client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
async def __aenter__(self):
self._input_chat = \
await self._client.get_input_entity(self._input_chat)
@ -447,9 +438,6 @@ class Conversation(ChatGetter):
"""Cancels the current conversation and exits the context manager."""
raise _ConversationCancelled()
def __exit__(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, exc_type, exc_val, exc_tb):
chat_id = utils.get_peer_id(self._chat_peer)
if self._client._ids_in_conversations[chat_id] == 1:
@ -461,6 +449,9 @@ class Conversation(ChatGetter):
self._cancel_all()
return isinstance(exc_val, _ConversationCancelled)
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
class _ConversationCancelled(InterruptedError):
pass