diff --git a/telethon/full_sync.py b/telethon/full_sync.py index 06aa73e2..94bd6e93 100644 --- a/telethon/full_sync.py +++ b/telethon/full_sync.py @@ -9,7 +9,7 @@ import asyncio import functools import inspect import threading -from concurrent import futures +from concurrent.futures import Future, ThreadPoolExecutor from async_generator import isasyncgenfunction @@ -30,25 +30,11 @@ async def _proxy_future(af, cf): def _sync_result(loop, x): - f = futures.Future() + f = Future() loop.call_soon_threadsafe(asyncio.ensure_future, _proxy_future(x, f)) return f.result() -def _syncify_coro(t, method_name, loop, thread_name): - method = getattr(t, method_name) - - @functools.wraps(method) - def syncified(*args, **kwargs): - coro = method(*args, **kwargs) - return ( - coro if threading.current_thread().name == thread_name - else _sync_result(loop, coro) - ) - - setattr(t, method_name, syncified) - - class _SyncGen: def __init__(self, loop, gen): self.loop = loop @@ -64,7 +50,7 @@ class _SyncGen: raise StopIteration from None -def _syncify_gen(t, method_name, loop, thread_name): +def _syncify_wrap(t, method_name, loop, thread_name, syncifier=_sync_result): method = getattr(t, method_name) @functools.wraps(method) @@ -72,7 +58,7 @@ def _syncify_gen(t, method_name, loop, thread_name): coro = method(*args, **kwargs) return ( coro if threading.current_thread().name == thread_name - else _SyncGen(loop, coro) + else syncifier(loop, coro) ) setattr(t, method_name, syncified) @@ -83,9 +69,9 @@ def _syncify(*types, loop, thread_name): for method_name in dir(t): if not method_name.startswith('_') or method_name == '__call__': if inspect.iscoroutinefunction(getattr(t, method_name)): - _syncify_coro(t, method_name, loop, thread_name) + _syncify_wrap(t, method_name, loop, thread_name, _sync_result) elif isasyncgenfunction(getattr(t, method_name)): - _syncify_gen(t, method_name, loop, thread_name) + _syncify_wrap(t, method_name, loop, thread_name, _SyncGen) __asyncthread = None @@ -94,42 +80,38 @@ __asyncthread = None def enable(loop=None, thread_name="__telethon_async_thread__"): global __asyncthread if __asyncthread is not None: - raise ValueError("full_sync has already been enabled.") + raise RuntimeError("full_sync can only be enabled once") if not loop: loop = asyncio.get_event_loop() - old_init = TelegramClient.__init__ - @functools.wraps(old_init) - def new_init(*args, **kwargs): - kwargs['loop'] = loop - return old_init(*args, **kwargs) - - TelegramClient.__init__ = new_init + TelegramClient.__init__ = functools.partialmethod(TelegramClient.__init__, + loop=loop) _syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter, SenderGetter, Forward, Message, InlineResult, Conversation, loop=loop, thread_name=thread_name) - _syncify_coro(TelegramClient, "start", loop, thread_name) + _syncify_wrap(TelegramClient, "start", loop, thread_name) old_add_event_handler = TelegramClient.add_event_handler old_remove_event_handler = TelegramClient.remove_event_handler proxied_event_handlers = {} @functools.wraps(old_add_event_handler) - def add_proxied_event_handler(self, callback, event_type=None): + def add_proxied_event_handler(self, callback, *args, **kwargs): async def _proxy(event): h_t = threading.Thread(target=callback, args=(event,)) h_t.start() proxied_event_handlers[callback] = _proxy - return old_add_event_handler(self, _proxy, event_type) + args = (self, callback, *args) + return old_add_event_handler(*args, **kwargs) @functools.wraps(old_remove_event_handler) - def remove_proxied_event_handler(self, callback, event_type=None): - return old_remove_event_handler( - self, proxied_event_handlers.get(callback, callback), event_type) + def remove_proxied_event_handler(self, callback, *args, **kwargs): + args = (self, proxied_event_handlers.get(callback, callback), *args) + return old_remove_event_handler(*args, **kwargs) TelegramClient.add_event_handler = add_proxied_event_handler TelegramClient.remove_event_handler = remove_proxied_event_handler @@ -143,7 +125,8 @@ def enable(loop=None, thread_name="__telethon_async_thread__"): asyncio.set_event_loop(loop) loop.run_forever() - __asyncthread = threading.Thread(target=start, name=thread_name) + __asyncthread = threading.Thread(target=start, name=thread_name, + daemon=True) __asyncthread.start() __asyncthread.loop = loop return __asyncthread @@ -152,5 +135,5 @@ def enable(loop=None, thread_name="__telethon_async_thread__"): def stop(): global __asyncthread if not __asyncthread: - raise ValueError("Can't find asyncio thread") + raise RuntimeError("Can't find asyncio thread") __asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop)