diff --git a/telethon/full_sync.py b/telethon/full_sync.py index 624e92f0..5f72d1df 100644 --- a/telethon/full_sync.py +++ b/telethon/full_sync.py @@ -51,7 +51,8 @@ class _SyncGen: raise StopIteration from None -def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result): +def _syncify_wrap(t, method_name, loop, thread_ident, + syncifier=_sync_result, rename=None): method = getattr(t, method_name) @functools.wraps(method) @@ -62,11 +63,25 @@ def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result): else syncifier(loop, coro) ) - setattr(t, method_name, syncified) + setattr(t, rename or method_name, syncified) def _syncify(*types, loop, thread_ident): for t in types: + # __enter__ and __exit__ need special care (VERY dirty hack). + # + # Normally we want them to raise if the loop is running because + # the user can't await there, and they need the async with variant. + # + # However they check if the loop is running to raise, which it is + # with full_sync enabled, so we patch them with the async variant. + if hasattr(t, '__aenter__'): + _syncify_wrap( + t, '__aenter__', loop, thread_ident, rename='__enter__') + + _syncify_wrap( + t, '__aexit__', loop, thread_ident, rename='__exit__') + for name in dir(t): if not name.startswith('_') or name == '__call__': meth = getattr(t, name)