diff --git a/telethon/full_sync.py b/telethon/full_sync.py index 94bd6e93..c2f1a87f 100644 --- a/telethon/full_sync.py +++ b/telethon/full_sync.py @@ -50,48 +50,60 @@ class _SyncGen: raise StopIteration from None -def _syncify_wrap(t, method_name, loop, thread_name, syncifier=_sync_result): +def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result): method = getattr(t, method_name) @functools.wraps(method) def syncified(*args, **kwargs): coro = method(*args, **kwargs) return ( - coro if threading.current_thread().name == thread_name + coro if threading.get_ident() == thread_ident else syncifier(loop, coro) ) setattr(t, method_name, syncified) -def _syncify(*types, loop, thread_name): +def _syncify(*types, loop, thread_ident): for t in types: for method_name in dir(t): if not method_name.startswith('_') or method_name == '__call__': if inspect.iscoroutinefunction(getattr(t, method_name)): - _syncify_wrap(t, method_name, loop, thread_name, _sync_result) + _syncify_wrap(t, method_name, loop, thread_ident, _sync_result) elif isasyncgenfunction(getattr(t, method_name)): - _syncify_wrap(t, method_name, loop, thread_name, _SyncGen) + _syncify_wrap(t, method_name, loop, thread_ident, _SyncGen) __asyncthread = None -def enable(loop=None, thread_name="__telethon_async_thread__"): +def enable(loop=None, executor=None, max_workers=1): global __asyncthread if __asyncthread is not None: raise RuntimeError("full_sync can only be enabled once") if not loop: loop = asyncio.get_event_loop() + if not executor: + executor = ThreadPoolExecutor(max_workers=max_workers) + + def start(): + asyncio.set_event_loop(loop) + loop.run_forever() + + __asyncthread = threading.Thread(target=start, name="__telethon_async_thread__", + daemon=True) + __asyncthread.start() + __asyncthread.loop = loop + __asyncthread.executor = executor TelegramClient.__init__ = functools.partialmethod(TelegramClient.__init__, loop=loop) _syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter, SenderGetter, Forward, Message, InlineResult, Conversation, - loop=loop, thread_name=thread_name) - _syncify_wrap(TelegramClient, "start", loop, thread_name) + loop=loop, thread_ident=__asyncthread.ident) + _syncify_wrap(TelegramClient, "start", loop, __asyncthread.ident) old_add_event_handler = TelegramClient.add_event_handler old_remove_event_handler = TelegramClient.remove_event_handler @@ -99,13 +111,12 @@ def enable(loop=None, thread_name="__telethon_async_thread__"): @functools.wraps(old_add_event_handler) def add_proxied_event_handler(self, callback, *args, **kwargs): - async def _proxy(event): - h_t = threading.Thread(target=callback, args=(event,)) - h_t.start() + def _proxy(*pargs, **pkwargs): + executor.submit(callback, *pargs, **pkwargs) proxied_event_handlers[callback] = _proxy - args = (self, callback, *args) + args = (self, _proxy, *args) return old_add_event_handler(*args, **kwargs) @functools.wraps(old_remove_event_handler) @@ -121,14 +132,6 @@ def enable(loop=None, thread_name="__telethon_async_thread__"): TelegramClient.run_until_disconnected = run_until_disconnected - def start(): - asyncio.set_event_loop(loop) - loop.run_forever() - - __asyncthread = threading.Thread(target=start, name=thread_name, - daemon=True) - __asyncthread.start() - __asyncthread.loop = loop return __asyncthread @@ -137,3 +140,4 @@ def stop(): if not __asyncthread: raise RuntimeError("Can't find asyncio thread") __asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop) + __asyncthread.executor.shutdown()