Reduce __enter__/__exit__ boilerplate for sync ctx managers

This commit is contained in:
Lonami Exo 2019-04-13 10:53:33 +02:00
parent badefcec48
commit 9090ede5db
5 changed files with 33 additions and 51 deletions

View File

@ -2,7 +2,7 @@ import functools
import inspect import inspect
from .users import UserMethods, _NOT_A_REQUEST from .users import UserMethods, _NOT_A_REQUEST
from .. import utils from .. import helpers, utils
from ..tl import functions, TLRequest from ..tl import functions, TLRequest
@ -31,15 +31,6 @@ class _TakeoutClient:
def success(self, value): def success(self, value):
self.__success = value self.__success = value
def __enter__(self):
if self.__client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.__client.loop.run_until_complete(self.__aenter__())
async def __aenter__(self): async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start. # Enter/Exit behaviour is "overrode", we don't want to call start.
client = self.__client client = self.__client
@ -50,9 +41,6 @@ class _TakeoutClient:
"takeout for the current session still not been finished yet.") "takeout for the current session still not been finished yet.")
return self return self
def __exit__(self, *args):
return self.__client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, exc_type, exc_value, traceback): async def __aexit__(self, exc_type, exc_value, traceback):
if self.__success is None and self.__finalize: if self.__success is None and self.__finalize:
self.__success = exc_type is None self.__success = exc_type is None
@ -64,6 +52,9 @@ class _TakeoutClient:
raise ValueError("Failed to finish the takeout.") raise ValueError("Failed to finish the takeout.")
self.session.takeout_id = None self.session.takeout_id = None
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
async def __call__(self, request, ordered=False): async def __call__(self, request, ordered=False):
takeout_id = self.__client.session.takeout_id takeout_id = self.__client.session.takeout_id
if takeout_id is None: if takeout_id is None:

View File

@ -536,23 +536,13 @@ class AuthMethods(MessageParseMethods, UserMethods):
# region with blocks # region with blocks
def __enter__(self):
if self._loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self.start()
async def __aenter__(self): async def __aenter__(self):
return await self.start() return await self.start()
def __exit__(self, *args):
# No loop.run_until_complete; it's already syncified
self.disconnect()
async def __aexit__(self, *args): async def __aexit__(self, *args):
await self.disconnect() await self.disconnect()
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
# endregion # endregion

View File

@ -3,7 +3,7 @@ import itertools
import string import string
from .users import UserMethods from .users import UserMethods
from .. import utils from .. import helpers, utils
from ..requestiter import RequestIter from ..requestiter import RequestIter
from ..tl import types, functions, custom from ..tl import types, functions, custom
@ -69,17 +69,8 @@ class _ChatAction:
self._task = None self._task = None
def __enter__(self): __enter__ = helpers._sync_enter
if self._client.loop.is_running(): __exit__ = helpers._sync_exit
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
def __exit__(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
async def _update(self): async def _update(self):
try: try:

View File

@ -102,6 +102,25 @@ async def _cancel(log, **tasks):
except Exception: except Exception:
log.exception('Unhandled exception from %s after cancel', name) log.exception('Unhandled exception from %s after cancel', name)
def _sync_enter(self):
"""
Helps to cut boilerplate on async context
managers that offer synchronous variants.
"""
if self._client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
def _sync_exit(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
# endregion # endregion
# region Cryptographic related utils # region Cryptographic related utils

View File

@ -3,7 +3,7 @@ import itertools
import time import time
from .chatgetter import ChatGetter from .chatgetter import ChatGetter
from ... import utils, errors from ... import helpers, utils, errors
# Sometimes the edits arrive very fast (within the same second). # Sometimes the edits arrive very fast (within the same second).
@ -403,15 +403,6 @@ class Conversation(ChatGetter):
else: else:
fut.cancel() fut.cancel()
def __enter__(self):
if self._client.loop.is_running():
raise RuntimeError(
'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")'
)
return self._client.loop.run_until_complete(self.__aenter__())
async def __aenter__(self): async def __aenter__(self):
self._input_chat = \ self._input_chat = \
await self._client.get_input_entity(self._input_chat) await self._client.get_input_entity(self._input_chat)
@ -447,9 +438,6 @@ class Conversation(ChatGetter):
"""Cancels the current conversation and exits the context manager.""" """Cancels the current conversation and exits the context manager."""
raise _ConversationCancelled() raise _ConversationCancelled()
def __exit__(self, *args):
return self._client.loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
chat_id = utils.get_peer_id(self._chat_peer) chat_id = utils.get_peer_id(self._chat_peer)
if self._client._ids_in_conversations[chat_id] == 1: if self._client._ids_in_conversations[chat_id] == 1:
@ -461,6 +449,9 @@ class Conversation(ChatGetter):
self._cancel_all() self._cancel_all()
return isinstance(exc_val, _ConversationCancelled) return isinstance(exc_val, _ConversationCancelled)
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
class _ConversationCancelled(InterruptedError): class _ConversationCancelled(InterruptedError):
pass pass