Refactor the _TakeoutClient proxy class

This commit is contained in:
Dmitry D. Chernov 2019-02-09 00:17:22 +10:00
parent 274fa72a8c
commit 3ca24d15da
2 changed files with 38 additions and 35 deletions

View File

@ -8,45 +8,48 @@ from ..tl import functions, TLRequest
class _TakeoutClient: class _TakeoutClient:
""" """
Proxy object over the client. `c` is the client, `k` it's class, Proxy object over the client.
`r` is the takeout request, and `t` is the takeout ID.
""" """
__PROXY_INTERFACE = ('__enter__', '__exit__', '__aenter__', '__aexit__')
def __init__(self, client, request): def __init__(self, client, request):
# We're a proxy object with __getattribute__overrode so we # We use the name mangling for attributes to make them inaccessible
# need to set attributes through the super class `object`. # from within the shadowed client object.
super().__setattr__('c', client) self.__client = client
super().__setattr__('k', client.__class__) self.__request = request
super().__setattr__('r', request)
super().__setattr__('t', None) # After we initialize the proxy object variables, it's necessary to
# translate to the client any write-access to our attributes.
self.__setattr__ = lambda _, name, value: \
setattr(client, name, value)
def __enter__(self): def __enter__(self):
# We also get self attributes through super() if self.__client.loop.is_running():
if super().__getattribute__('c').loop.is_running():
raise RuntimeError( raise RuntimeError(
'You must use "async with" if the event loop ' 'You must use "async with" if the event loop '
'is running (i.e. you are inside an "async def")' 'is running (i.e. you are inside an "async def")'
) )
return super().__getattribute__( return self.__client.loop.run_until_complete(self.__aenter__())
'c').loop.run_until_complete(self.__aenter__())
async def __aenter__(self): async def __aenter__(self):
# Enter/Exit behaviour is "overrode", we don't want to call start # Enter/Exit behaviour is "overrode", we don't want to call start.
cl = super().__getattribute__('c') # TODO: Request only if takeout ID isn't set.
super().__setattr__('t', (await cl(super().__getattribute__('r'))).id) self.__client.takeout_id = (await self.__client(self.__request)).id
return self return self
def __exit__(self, *args): def __exit__(self, *args):
return super().__getattribute__( return self.__client.loop.run_until_complete(self.__aexit__(*args))
'c').loop.run_until_complete(self.__aexit__(*args))
async def __aexit__(self, *args): async def __aexit__(self, *args):
super().__setattr__('t', None) # TODO: Reset only if takeout result is set.
self.__client.takeout_id = None
async def __call__(self, request, ordered=False): async def __call__(self, request, ordered=False):
takeout_id = super().__getattribute__('t') takeout_id = self.__client.takeout_id
if takeout_id is None: if takeout_id is None:
raise ValueError('Cannot call takeout methods outside of "with"') raise ValueError('Takeout mode has not been initialized '
'(are you calling outside of "with"?)')
single = not utils.is_list_like(request) single = not utils.is_list_like(request)
requests = ((request,) if single else request) requests = ((request,) if single else request)
@ -57,28 +60,27 @@ class _TakeoutClient:
await r.resolve(self, utils) await r.resolve(self, utils)
wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r)) wrapped.append(functions.InvokeWithTakeoutRequest(takeout_id, r))
return await super().__getattribute__('c')( return await self.__client(
wrapped[0] if single else wrapped, ordered=ordered) wrapped[0] if single else wrapped, ordered=ordered)
def __getattribute__(self, name): def __getattribute__(self, name):
if name.startswith('__'): # We access class via type() because __class__ will recurse infinitely.
# We want to override special method names if name.startswith('__') and name not in type(self).__PROXY_INTERFACE:
if name == '__class__': raise AttributeError # force call of __getattr__
# See https://github.com/LonamiWebs/Telethon/issues/1103.
name = 'k' # Try to access attribute in the proxy object and check for the same
# attribute in the shadowed object (through our __getattr__) if failed.
return super().__getattribute__(name) return super().__getattribute__(name)
value = getattr(super().__getattribute__('c'), name) def __getattr__(self, name):
value = getattr(self.__client, name)
if inspect.ismethod(value): if inspect.ismethod(value):
# Emulate bound methods behaviour by partially applying # Emulate bound methods behavior by partially applying our proxy
# our proxy class as the self parameter instead of the client # class as the self parameter instead of the client.
return functools.partial( return functools.partial(
getattr(super().__getattribute__('k'), name), self) getattr(self.__client.__class__, name), self)
else:
return value
def __setattr__(self, name, value): return value
setattr(super().__getattribute__('c'), name, value)
class AccountMethods(UserMethods): class AccountMethods(UserMethods):

View File

@ -233,6 +233,7 @@ class TelegramBaseClient(abc.ABC):
# With asynchronous sessions, it would need await, # With asynchronous sessions, it would need await,
# and defeats the purpose of properties. # and defeats the purpose of properties.
self.session = session self.session = session
self.takeout_id = None # TODO: Move to session.
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash