Use ident instead of thread name and use ThreadPoolExecutor

This commit is contained in:
Tulir Asokan 2018-10-02 12:56:25 +03:00
parent ac7bc5d68a
commit 1285f11405

View File

@ -50,48 +50,60 @@ class _SyncGen:
raise StopIteration from None raise StopIteration from None
def _syncify_wrap(t, method_name, loop, thread_name, syncifier=_sync_result): def _syncify_wrap(t, method_name, loop, thread_ident, syncifier=_sync_result):
method = getattr(t, method_name) method = getattr(t, method_name)
@functools.wraps(method) @functools.wraps(method)
def syncified(*args, **kwargs): def syncified(*args, **kwargs):
coro = method(*args, **kwargs) coro = method(*args, **kwargs)
return ( return (
coro if threading.current_thread().name == thread_name coro if threading.get_ident() == thread_ident
else syncifier(loop, coro) else syncifier(loop, coro)
) )
setattr(t, method_name, syncified) setattr(t, method_name, syncified)
def _syncify(*types, loop, thread_name): def _syncify(*types, loop, thread_ident):
for t in types: for t in types:
for method_name in dir(t): for method_name in dir(t):
if not method_name.startswith('_') or method_name == '__call__': if not method_name.startswith('_') or method_name == '__call__':
if inspect.iscoroutinefunction(getattr(t, method_name)): if inspect.iscoroutinefunction(getattr(t, method_name)):
_syncify_wrap(t, method_name, loop, thread_name, _sync_result) _syncify_wrap(t, method_name, loop, thread_ident, _sync_result)
elif isasyncgenfunction(getattr(t, method_name)): elif isasyncgenfunction(getattr(t, method_name)):
_syncify_wrap(t, method_name, loop, thread_name, _SyncGen) _syncify_wrap(t, method_name, loop, thread_ident, _SyncGen)
__asyncthread = None __asyncthread = None
def enable(loop=None, thread_name="__telethon_async_thread__"): def enable(loop=None, executor=None, max_workers=1):
global __asyncthread global __asyncthread
if __asyncthread is not None: if __asyncthread is not None:
raise RuntimeError("full_sync can only be enabled once") raise RuntimeError("full_sync can only be enabled once")
if not loop: if not loop:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if not executor:
executor = ThreadPoolExecutor(max_workers=max_workers)
def start():
asyncio.set_event_loop(loop)
loop.run_forever()
__asyncthread = threading.Thread(target=start, name="__telethon_async_thread__",
daemon=True)
__asyncthread.start()
__asyncthread.loop = loop
__asyncthread.executor = executor
TelegramClient.__init__ = functools.partialmethod(TelegramClient.__init__, TelegramClient.__init__ = functools.partialmethod(TelegramClient.__init__,
loop=loop) loop=loop)
_syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter, _syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter,
SenderGetter, Forward, Message, InlineResult, Conversation, SenderGetter, Forward, Message, InlineResult, Conversation,
loop=loop, thread_name=thread_name) loop=loop, thread_ident=__asyncthread.ident)
_syncify_wrap(TelegramClient, "start", loop, thread_name) _syncify_wrap(TelegramClient, "start", loop, __asyncthread.ident)
old_add_event_handler = TelegramClient.add_event_handler old_add_event_handler = TelegramClient.add_event_handler
old_remove_event_handler = TelegramClient.remove_event_handler old_remove_event_handler = TelegramClient.remove_event_handler
@ -99,13 +111,12 @@ def enable(loop=None, thread_name="__telethon_async_thread__"):
@functools.wraps(old_add_event_handler) @functools.wraps(old_add_event_handler)
def add_proxied_event_handler(self, callback, *args, **kwargs): def add_proxied_event_handler(self, callback, *args, **kwargs):
async def _proxy(event): def _proxy(*pargs, **pkwargs):
h_t = threading.Thread(target=callback, args=(event,)) executor.submit(callback, *pargs, **pkwargs)
h_t.start()
proxied_event_handlers[callback] = _proxy proxied_event_handlers[callback] = _proxy
args = (self, callback, *args) args = (self, _proxy, *args)
return old_add_event_handler(*args, **kwargs) return old_add_event_handler(*args, **kwargs)
@functools.wraps(old_remove_event_handler) @functools.wraps(old_remove_event_handler)
@ -121,14 +132,6 @@ def enable(loop=None, thread_name="__telethon_async_thread__"):
TelegramClient.run_until_disconnected = run_until_disconnected TelegramClient.run_until_disconnected = run_until_disconnected
def start():
asyncio.set_event_loop(loop)
loop.run_forever()
__asyncthread = threading.Thread(target=start, name=thread_name,
daemon=True)
__asyncthread.start()
__asyncthread.loop = loop
return __asyncthread return __asyncthread
@ -137,3 +140,4 @@ def stop():
if not __asyncthread: if not __asyncthread:
raise RuntimeError("Can't find asyncio thread") raise RuntimeError("Can't find asyncio thread")
__asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop) __asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop)
__asyncthread.executor.shutdown()