Fix problems and simplify code

This commit is contained in:
Tulir Asokan 2018-10-02 12:22:23 +03:00
parent 362e5b01e2
commit ac7bc5d68a

View File

@ -9,7 +9,7 @@ import asyncio
import functools import functools
import inspect import inspect
import threading import threading
from concurrent import futures from concurrent.futures import Future, ThreadPoolExecutor
from async_generator import isasyncgenfunction from async_generator import isasyncgenfunction
@ -30,25 +30,11 @@ async def _proxy_future(af, cf):
def _sync_result(loop, x): def _sync_result(loop, x):
f = futures.Future() f = Future()
loop.call_soon_threadsafe(asyncio.ensure_future, _proxy_future(x, f)) loop.call_soon_threadsafe(asyncio.ensure_future, _proxy_future(x, f))
return f.result() return f.result()
def _syncify_coro(t, method_name, loop, thread_name):
method = getattr(t, method_name)
@functools.wraps(method)
def syncified(*args, **kwargs):
coro = method(*args, **kwargs)
return (
coro if threading.current_thread().name == thread_name
else _sync_result(loop, coro)
)
setattr(t, method_name, syncified)
class _SyncGen: class _SyncGen:
def __init__(self, loop, gen): def __init__(self, loop, gen):
self.loop = loop self.loop = loop
@ -64,7 +50,7 @@ class _SyncGen:
raise StopIteration from None raise StopIteration from None
def _syncify_gen(t, method_name, loop, thread_name): def _syncify_wrap(t, method_name, loop, thread_name, syncifier=_sync_result):
method = getattr(t, method_name) method = getattr(t, method_name)
@functools.wraps(method) @functools.wraps(method)
@ -72,7 +58,7 @@ def _syncify_gen(t, method_name, loop, thread_name):
coro = method(*args, **kwargs) coro = method(*args, **kwargs)
return ( return (
coro if threading.current_thread().name == thread_name coro if threading.current_thread().name == thread_name
else _SyncGen(loop, coro) else syncifier(loop, coro)
) )
setattr(t, method_name, syncified) setattr(t, method_name, syncified)
@ -83,9 +69,9 @@ def _syncify(*types, loop, thread_name):
for method_name in dir(t): for method_name in dir(t):
if not method_name.startswith('_') or method_name == '__call__': if not method_name.startswith('_') or method_name == '__call__':
if inspect.iscoroutinefunction(getattr(t, method_name)): if inspect.iscoroutinefunction(getattr(t, method_name)):
_syncify_coro(t, method_name, loop, thread_name) _syncify_wrap(t, method_name, loop, thread_name, _sync_result)
elif isasyncgenfunction(getattr(t, method_name)): elif isasyncgenfunction(getattr(t, method_name)):
_syncify_gen(t, method_name, loop, thread_name) _syncify_wrap(t, method_name, loop, thread_name, _SyncGen)
__asyncthread = None __asyncthread = None
@ -94,42 +80,38 @@ __asyncthread = None
def enable(loop=None, thread_name="__telethon_async_thread__"): def enable(loop=None, thread_name="__telethon_async_thread__"):
global __asyncthread global __asyncthread
if __asyncthread is not None: if __asyncthread is not None:
raise ValueError("full_sync has already been enabled.") raise RuntimeError("full_sync can only be enabled once")
if not loop: if not loop:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
old_init = TelegramClient.__init__
@functools.wraps(old_init) TelegramClient.__init__ = functools.partialmethod(TelegramClient.__init__,
def new_init(*args, **kwargs): loop=loop)
kwargs['loop'] = loop
return old_init(*args, **kwargs)
TelegramClient.__init__ = new_init
_syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter, _syncify(TelegramClient, Draft, Dialog, MessageButton, ChatGetter,
SenderGetter, Forward, Message, InlineResult, Conversation, SenderGetter, Forward, Message, InlineResult, Conversation,
loop=loop, thread_name=thread_name) loop=loop, thread_name=thread_name)
_syncify_coro(TelegramClient, "start", loop, thread_name) _syncify_wrap(TelegramClient, "start", loop, thread_name)
old_add_event_handler = TelegramClient.add_event_handler old_add_event_handler = TelegramClient.add_event_handler
old_remove_event_handler = TelegramClient.remove_event_handler old_remove_event_handler = TelegramClient.remove_event_handler
proxied_event_handlers = {} proxied_event_handlers = {}
@functools.wraps(old_add_event_handler) @functools.wraps(old_add_event_handler)
def add_proxied_event_handler(self, callback, event_type=None): def add_proxied_event_handler(self, callback, *args, **kwargs):
async def _proxy(event): async def _proxy(event):
h_t = threading.Thread(target=callback, args=(event,)) h_t = threading.Thread(target=callback, args=(event,))
h_t.start() h_t.start()
proxied_event_handlers[callback] = _proxy proxied_event_handlers[callback] = _proxy
return old_add_event_handler(self, _proxy, event_type) args = (self, callback, *args)
return old_add_event_handler(*args, **kwargs)
@functools.wraps(old_remove_event_handler) @functools.wraps(old_remove_event_handler)
def remove_proxied_event_handler(self, callback, event_type=None): def remove_proxied_event_handler(self, callback, *args, **kwargs):
return old_remove_event_handler( args = (self, proxied_event_handlers.get(callback, callback), *args)
self, proxied_event_handlers.get(callback, callback), event_type) return old_remove_event_handler(*args, **kwargs)
TelegramClient.add_event_handler = add_proxied_event_handler TelegramClient.add_event_handler = add_proxied_event_handler
TelegramClient.remove_event_handler = remove_proxied_event_handler TelegramClient.remove_event_handler = remove_proxied_event_handler
@ -143,7 +125,8 @@ def enable(loop=None, thread_name="__telethon_async_thread__"):
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
loop.run_forever() loop.run_forever()
__asyncthread = threading.Thread(target=start, name=thread_name) __asyncthread = threading.Thread(target=start, name=thread_name,
daemon=True)
__asyncthread.start() __asyncthread.start()
__asyncthread.loop = loop __asyncthread.loop = loop
return __asyncthread return __asyncthread
@ -152,5 +135,5 @@ def enable(loop=None, thread_name="__telethon_async_thread__"):
def stop(): def stop():
global __asyncthread global __asyncthread
if not __asyncthread: if not __asyncthread:
raise ValueError("Can't find asyncio thread") raise RuntimeError("Can't find asyncio thread")
__asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop) __asyncthread.loop.call_soon_threadsafe(__asyncthread.loop.stop)