diff --git a/telethon/full_sync.py b/telethon/full_sync.py index a8c06799..624e92f0 100644 --- a/telethon/full_sync.py +++ b/telethon/full_sync.py @@ -67,12 +67,14 @@ def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result): def _syncify(*types, loop, thread_ident): for t in types: - for method_name in dir(t): - if not method_name.startswith('_') or method_name == '__call__': - if inspect.iscoroutinefunction(getattr(t, method_name)): - _syncify_wrap(t, method_name, loop, thread_ident, _sync_result) - elif isasyncgenfunction(getattr(t, method_name)): - _syncify_wrap(t, method_name, loop, thread_ident, _SyncGen) + for name in dir(t): + if not name.startswith('_') or name == '__call__': + meth = getattr(t, name) + meth = getattr(meth, '__tl.sync', meth) + if inspect.iscoroutinefunction(meth): + _syncify_wrap(t, name, loop, thread_ident) + elif isasyncgenfunction(meth): + _syncify_wrap(t, name, loop, thread_ident, _SyncGen) __asyncthread = None diff --git a/telethon/sync.py b/telethon/sync.py index 15bc5859..d47500f3 100644 --- a/telethon/sync.py +++ b/telethon/sync.py @@ -24,20 +24,6 @@ from .tl.custom.chatgetter import ChatGetter from .tl.custom.sendergetter import SenderGetter -def _syncify_coro(t, method_name): - method = getattr(t, method_name) - - @functools.wraps(method) - def syncified(*args, **kwargs): - coro = method(*args, **kwargs) - return ( - coro if asyncio.get_event_loop().is_running() - else asyncio.get_event_loop().run_until_complete(coro) - ) - - setattr(t, method_name, syncified) - - class _SyncGen: def __init__(self, loop, gen): self.loop = loop @@ -53,7 +39,7 @@ class _SyncGen: raise StopIteration from None -def _syncify_gen(t, method_name): +def _syncify_wrap(t, method_name, syncifier): method = getattr(t, method_name) @functools.wraps(method) @@ -61,9 +47,11 @@ def _syncify_gen(t, method_name): coro = method(*args, **kwargs) return ( coro if asyncio.get_event_loop().is_running() - else _SyncGen(asyncio.get_event_loop(), coro) + else syncifier(coro) ) + # Save an accessible reference to the original method + setattr(syncified, '__tl.sync', method) setattr(t, method_name, syncified) @@ -73,13 +61,14 @@ def syncify(*types): into synchronous, which return either the coroutine or the result based on whether ``asyncio's`` event loop is running. """ + loop = asyncio.get_event_loop() for t in types: - for method_name in dir(t): - if not method_name.startswith('_') or method_name == '__call__': - if inspect.iscoroutinefunction(getattr(t, method_name)): - _syncify_coro(t, method_name) - elif isasyncgenfunction(getattr(t, method_name)): - _syncify_gen(t, method_name) + for name in dir(t): + if not name.startswith('_') or name == '__call__': + if inspect.iscoroutinefunction(getattr(t, name)): + _syncify_wrap(t, name, loop.run_until_complete) + elif isasyncgenfunction(getattr(t, name)): + _syncify_wrap(t, name, functools.partial(_SyncGen, loop)) syncify(TelegramClient, Draft, Dialog, MessageButton,