diff --git a/telethon/sync.py b/telethon/sync.py index d47500f3..925b2046 100644 --- a/telethon/sync.py +++ b/telethon/sync.py @@ -25,8 +25,7 @@ from .tl.custom.sendergetter import SenderGetter class _SyncGen: - def __init__(self, loop, gen): - self.loop = loop + def __init__(self, gen): self.gen = gen def __iter__(self): @@ -34,21 +33,25 @@ class _SyncGen: def __next__(self): try: - return self.loop.run_until_complete(self.gen.__anext__()) + return asyncio.get_event_loop() \ + .run_until_complete(self.gen.__anext__()) except StopAsyncIteration: raise StopIteration from None -def _syncify_wrap(t, method_name, syncifier): +def _syncify_wrap(t, method_name, gen): method = getattr(t, method_name) @functools.wraps(method) def syncified(*args, **kwargs): coro = method(*args, **kwargs) - return ( - coro if asyncio.get_event_loop().is_running() - else syncifier(coro) - ) + loop = asyncio.get_event_loop() + if loop.is_running(): + return coro + elif gen: + return _SyncGen(coro) + else: + return loop.run_until_complete(coro) # Save an accessible reference to the original method setattr(syncified, '__tl.sync', method) @@ -61,14 +64,13 @@ def syncify(*types): into synchronous, which return either the coroutine or the result based on whether ``asyncio's`` event loop is running. """ - loop = asyncio.get_event_loop() for t in types: for name in dir(t): if not name.startswith('_') or name == '__call__': if inspect.iscoroutinefunction(getattr(t, name)): - _syncify_wrap(t, name, loop.run_until_complete) + _syncify_wrap(t, name, gen=False) elif isasyncgenfunction(getattr(t, name)): - _syncify_wrap(t, name, functools.partial(_SyncGen, loop)) + _syncify_wrap(t, name, gen=True) syncify(TelegramClient, Draft, Dialog, MessageButton,