Clean-up client's __init__ and remove entity cache

Entity cache uses are removed. It was a source of ever-growing memory
usage that has to be reworked. This affects everything that tried to
obtain an input entity, input sender or input chat (such as the
SenderGetter or calls to _get_entity_pair). Input entities need to be
reworked in any case.

Its removal also affects the automatic cache of any raw API request.

Raise last error parameter is removed, and its behaviour made default.

The connection type parameter has been removed, since users really have
no need to change it.

A few more attributes have been made private, since users should not
mess with those.
This commit is contained in:
Lonami Exo 2022-01-18 12:52:22 +01:00
parent 85a9c13129
commit f8264abb5a
18 changed files with 104 additions and 253 deletions

View File

@ -761,3 +761,13 @@ also mark read only supports single now. a list would just be max anyway.
removed max id since it's not really of much use. removed max id since it's not really of much use.
client loop has been removed. embrace implicit loop as asyncio does now client loop has been removed. embrace implicit loop as asyncio does now
renamed some client params, and made other privates
timeout -> connect_timeout
connection_retries -> connect_retries
retry_delay -> connect_retry_delay
sequential_updates is gone
connection type is gone
raise_last_call_error is now the default rather than ValueError

View File

@ -241,7 +241,7 @@ async def sign_in(
elif bot_token: elif bot_token:
request = _tl.fn.auth.ImportBotAuthorization( request = _tl.fn.auth.ImportBotAuthorization(
flags=0, bot_auth_token=bot_token, flags=0, bot_auth_token=bot_token,
api_id=self.api_id, api_hash=self.api_hash api_id=self._api_id, api_hash=self._api_hash
) )
else: else:
raise ValueError('You must provide either phone and code, password, or bot_token.') raise ValueError('You must provide either phone and code, password, or bot_token.')
@ -313,8 +313,6 @@ async def _update_session_state(self, user, save=True):
Callback called whenever the login or sign up process completes. Callback called whenever the login or sign up process completes.
Returns the input user parameter. Returns the input user parameter.
""" """
self._authorized = True
state = await self(_tl.fn.updates.GetState()) state = await self(_tl.fn.updates.GetState())
await _replace_session_state( await _replace_session_state(
self, self,
@ -332,11 +330,11 @@ async def _update_session_state(self, user, save=True):
async def _replace_session_state(self, *, save=True, **changes): async def _replace_session_state(self, *, save=True, **changes):
new = dataclasses.replace(self._session_state, **changes) new = dataclasses.replace(self._session_state, **changes)
await self.session.set_state(new) await self._session.set_state(new)
self._session_state = new self._session_state = new
if save: if save:
await self.session.save() await self._session.save()
async def send_code_request( async def send_code_request(
@ -354,7 +352,7 @@ async def send_code_request(
else: else:
try: try:
result = await self(_tl.fn.auth.SendCode( result = await self(_tl.fn.auth.SendCode(
phone, self.api_id, self.api_hash, _tl.CodeSettings())) phone, self._api_id, self._api_hash, _tl.CodeSettings()))
except errors.AuthRestartError: except errors.AuthRestartError:
return await self.send_code_request(phone) return await self.send_code_request(phone)
@ -377,7 +375,6 @@ async def log_out(self: 'TelegramClient') -> bool:
except errors.RPCError: except errors.RPCError:
return False return False
self._authorized = False
self._state_cache.reset() self._state_cache.reset()
await self.disconnect() await self.disconnect()

View File

@ -11,7 +11,7 @@ import dataclasses
from .. import version, __name__ as __base_name__, _tl from .. import version, __name__ as __base_name__, _tl
from .._crypto import rsa from .._crypto import rsa
from .._misc import markdown, entitycache, statecache, enums, helpers from .._misc import markdown, statecache, enums, helpers
from .._network import MTProtoSender, Connection, transports from .._network import MTProtoSender, Connection, transports
from .._sessions import Session, SQLiteSession, MemorySession from .._sessions import Session, SQLiteSession, MemorySession
from .._sessions.types import DataCenter, SessionState from .._sessions.types import DataCenter, SessionState
@ -71,33 +71,28 @@ def init(
api_id: int, api_id: int,
api_hash: str, api_hash: str,
*, *,
connection: 'typing.Type[Connection]' = (), # Logging.
base_logger: typing.Union[str, logging.Logger] = None,
# Connection parameters.
use_ipv6: bool = False, use_ipv6: bool = False,
proxy: typing.Union[tuple, dict] = None, proxy: typing.Union[tuple, dict] = None,
local_addr: typing.Union[str, tuple] = None, local_addr: typing.Union[str, tuple] = None,
timeout: int = 10,
request_retries: int = 5,
connection_retries: int = 5,
retry_delay: int = 1,
auto_reconnect: bool = True,
sequential_updates: bool = False,
flood_sleep_threshold: int = 60,
raise_last_call_error: bool = False,
device_model: str = None, device_model: str = None,
system_version: str = None, system_version: str = None,
app_version: str = None, app_version: str = None,
lang_code: str = 'en', lang_code: str = 'en',
system_lang_code: str = 'en', system_lang_code: str = 'en',
base_logger: typing.Union[str, logging.Logger] = None, # Nice-to-have.
receive_updates: bool = True auto_reconnect: bool = True,
connect_timeout: int = 10,
connect_retries: int = 4,
connect_retry_delay: int = 1,
request_retries: int = 4,
flood_sleep_threshold: int = 60,
# Update handling.
receive_updates: bool = True,
): ):
if not api_id or not api_hash: # Logging.
raise ValueError(
"Your API ID or Hash cannot be empty or None. "
"Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6
if isinstance(base_logger, str): if isinstance(base_logger, str):
base_logger = logging.getLogger(base_logger) base_logger = logging.getLogger(base_logger)
elif not isinstance(base_logger, logging.Logger): elif not isinstance(base_logger, logging.Logger):
@ -112,7 +107,7 @@ def init(
self._log = _Loggers() self._log = _Loggers()
# Determine what session object we have # Sessions.
if isinstance(session, str) or session is None: if isinstance(session, str) or session is None:
try: try:
session = SQLiteSession(session) session = SQLiteSession(session)
@ -131,57 +126,38 @@ def init(
'The given session must be a str or a Session instance.' 'The given session must be a str or a Session instance.'
) )
self.flood_sleep_threshold = flood_sleep_threshold self._session = session
# In-memory copy of the session's state to avoid a roundtrip as it contains commonly-accessed values.
# TODO Use AsyncClassWrapper(session)
# ChatGetter and SenderGetter can use the in-memory _entity_cache
# to avoid network access and the need for await in session files.
#
# The session files only wants the entities to persist
# them to disk, and to save additional useful information.
# TODO Session should probably return all cached
# info of entities, not just the input versions
self.session = session
# Cache session data for convenient access
self._session_state = None self._session_state = None
self._all_dcs = None
self._state_cache = statecache.StateCache(None, self._log)
self._entity_cache = entitycache.EntityCache() # Nice-to-have.
self.api_id = int(api_id) self._request_retries = request_retries
self.api_hash = api_hash self._connect_retries = connect_retries
self._connect_retry_delay = connect_retry_delay or 0
self._connect_timeout = connect_timeout
self.flood_sleep_threshold = flood_sleep_threshold
self._flood_waited_requests = {} # prevent calls that would floodwait entirely
self._parse_mode = markdown
# Connection parameters.
if not api_id or not api_hash:
raise ValueError(
"Your API ID or Hash cannot be empty or None. "
"Refer to telethon.rtfd.io for more information.")
if local_addr is not None: if local_addr is not None:
if use_ipv6 is False and ':' in local_addr: if use_ipv6 is False and ':' in local_addr:
raise TypeError( raise TypeError('A local IPv6 address must only be used with `use_ipv6=True`.')
'A local IPv6 address must only be used with `use_ipv6=True`.'
)
elif use_ipv6 is True and ':' not in local_addr: elif use_ipv6 is True and ':' not in local_addr:
raise TypeError( raise TypeError('`use_ipv6=True` must only be used with a local IPv6 address.')
'`use_ipv6=True` must only be used with a local IPv6 address.'
)
self._raise_last_call_error = raise_last_call_error self._transport = transports.Full()
self._use_ipv6 = use_ipv6
self._request_retries = request_retries
self._connection_retries = connection_retries
self._retry_delay = retry_delay or 0
self._proxy = proxy
self._local_addr = local_addr self._local_addr = local_addr
self._timeout = timeout self._proxy = proxy
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
self._api_id = int(api_id)
if connection == (): self._api_hash = api_hash
# For now the current default remains TCP Full; may change to be "smart" if proxies are specified
connection = enums.ConnectionMode.FULL
self._transport = {
enums.ConnectionMode.FULL: transports.Full(),
enums.ConnectionMode.INTERMEDIATE: transports.Intermediate(),
enums.ConnectionMode.ABRIDGED: transports.Abridged(),
}[enums.ConnectionMode(connection)]
init_proxy = None
# Used on connection. Capture the variables in a lambda since # Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayer. # exporting clients need to create this InvokeWithLayer.
@ -196,7 +172,7 @@ def init(
default_system_version = re.sub(r'-.+','',system.release) default_system_version = re.sub(r'-.+','',system.release)
self._init_request = _tl.fn.InitConnection( self._init_request = _tl.fn.InitConnection(
api_id=self.api_id, api_id=self._api_id,
device_model=device_model or default_device_model or 'Unknown', device_model=device_model or default_device_model or 'Unknown',
system_version=system_version or default_system_version or '1.0', system_version=system_version or default_system_version or '1.0',
app_version=app_version or self.__version__, app_version=app_version or self.__version__,
@ -204,65 +180,26 @@ def init(
system_lang_code=system_lang_code, system_lang_code=system_lang_code,
lang_pack='', # "langPacks are for official apps only" lang_pack='', # "langPacks are for official apps only"
query=None, query=None,
proxy=init_proxy proxy=None
) )
self._sender = MTProtoSender( self._sender = MTProtoSender(
loggers=self._log, loggers=self._log,
retries=self._connection_retries, retries=self._connect_retries,
delay=self._retry_delay, delay=self._connect_retry_delay,
auto_reconnect=self._auto_reconnect, auto_reconnect=self._auto_reconnect,
connect_timeout=self._timeout, connect_timeout=self._connect_timeout,
update_callback=self._handle_update, update_callback=self._handle_update,
auto_reconnect_callback=self._handle_auto_reconnect auto_reconnect_callback=self._handle_auto_reconnect
) )
# Remember flood-waited requests to avoid making them again # Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders.
self._flood_waited_requests = {}
# Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders
self._borrowed_senders = {} self._borrowed_senders = {}
self._borrow_sender_lock = asyncio.Lock() self._borrow_sender_lock = asyncio.Lock()
self._updates_handle = None # Update handling.
self._last_request = time.time()
self._channel_pts = {}
self._no_updates = not receive_updates self._no_updates = not receive_updates
if sequential_updates:
self._updates_queue = asyncio.Queue()
self._dispatching_updates_queue = asyncio.Event()
else:
# Use a set of pending instead of a queue so we can properly
# terminate all pending updates on disconnect.
self._updates_queue = set()
self._dispatching_updates_queue = None
self._authorized = None # None = unknown, False = no, True = yes
# Some further state for subclasses
self._event_builders = []
# Hack to workaround the fact Telegram may send album updates as
# different Updates when being sent from a different data center.
# {grouped_id: AlbumHack}
#
# FIXME: We don't bother cleaning this up because it's not really
# worth it, albums are pretty rare and this only holds them
# for a second at most.
self._albums = {}
# Default parse mode
self._parse_mode = markdown
# Some fields to easy signing in. Let {phone: hash} be
# a dictionary because the user may change their mind.
self._phone_code_hash = {}
self._phone = None
self._tos = None
# A place to store if channels are a megagroup or not (see `edit_admin`)
self._megagroup_cache = {}
def get_flood_sleep_threshold(self): def get_flood_sleep_threshold(self):
return self._flood_sleep_threshold return self._flood_sleep_threshold
@ -273,8 +210,8 @@ def set_flood_sleep_threshold(self, value):
async def connect(self: 'TelegramClient') -> None: async def connect(self: 'TelegramClient') -> None:
self._all_dcs = {dc.id: dc for dc in await self.session.get_all_dc()} self._all_dcs = {dc.id: dc for dc in await self._session.get_all_dc()}
self._session_state = await self.session.get_state() self._session_state = await self._session.get_state()
if self._session_state is None: if self._session_state is None:
try_fetch_user = False try_fetch_user = False
@ -347,7 +284,7 @@ async def connect(self: 'TelegramClient') -> None:
self._all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'') self._all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in self._all_dcs.values(): for dc in self._all_dcs.values():
await self.session.insert_dc(dc) await self._session.insert_dc(dc)
if try_fetch_user: if try_fetch_user:
# If there was a previous session state, but the current user ID is 0, it means we've # If there was a previous session state, but the current user ID is 0, it means we've
@ -357,7 +294,7 @@ async def connect(self: 'TelegramClient') -> None:
if me: if me:
await self._update_session_state(me, save=False) await self._update_session_state(me, save=False)
await self.session.save() await self._session.save()
self._updates_handle = asyncio.create_task(self._update_loop()) self._updates_handle = asyncio.create_task(self._update_loop())

View File

@ -2645,25 +2645,26 @@ class TelegramClient:
api_id: int, api_id: int,
api_hash: str, api_hash: str,
*, *,
connection: typing.Union[str, enums.ConnectionMode] = (), # Logging.
base_logger: typing.Union[str, logging.Logger] = None,
# Connection parameters.
use_ipv6: bool = False, use_ipv6: bool = False,
proxy: typing.Union[tuple, dict] = None, proxy: typing.Union[tuple, dict] = None,
local_addr: typing.Union[str, tuple] = None, local_addr: typing.Union[str, tuple] = None,
timeout: int = 10,
request_retries: int = 5,
connection_retries: int = 5,
retry_delay: int = 1,
auto_reconnect: bool = True,
sequential_updates: bool = False,
flood_sleep_threshold: int = 60,
raise_last_call_error: bool = False,
device_model: str = None, device_model: str = None,
system_version: str = None, system_version: str = None,
app_version: str = None, app_version: str = None,
lang_code: str = 'en', lang_code: str = 'en',
system_lang_code: str = 'en', system_lang_code: str = 'en',
base_logger: typing.Union[str, logging.Logger] = None, # Nice-to-have.
receive_updates: bool = True auto_reconnect: bool = True,
connect_timeout: int = 10,
connect_retries: int = 4,
connect_retry_delay: int = 1,
request_retries: int = 4,
flood_sleep_threshold: int = 60,
# Update handling.
receive_updates: bool = True,
): ):
telegrambaseclient.init(**locals()) telegrambaseclient.init(**locals())

View File

@ -206,7 +206,7 @@ async def _update_loop(self: 'TelegramClient'):
# Entities are not saved when they are inserted because this is a rather expensive # Entities are not saved when they are inserted because this is a rather expensive
# operation (default's sqlite3 takes ~0.1s to commit changes). Do it every minute # operation (default's sqlite3 takes ~0.1s to commit changes). Do it every minute
# instead. No-op if there's nothing new. # instead. No-op if there's nothing new.
await self.session.save() await self._session.save()
# We need to send some content-related request at least hourly # We need to send some content-related request at least hourly
# for Telegram to keep delivering updates, otherwise they will # for Telegram to keep delivering updates, otherwise they will
@ -231,33 +231,6 @@ async def _dispatch_queue_updates(self: 'TelegramClient'):
self._dispatching_updates_queue.clear() self._dispatching_updates_queue.clear()
async def _dispatch_update(self: 'TelegramClient', update, entities, others, channel_id, pts_date): async def _dispatch_update(self: 'TelegramClient', update, entities, others, channel_id, pts_date):
if entities:
rows = self._entity_cache.add(list(entities.values()))
if rows:
await self.session.insert_entities(rows)
if not self._entity_cache.ensure_cached(update):
# We could add a lock to not fetch the same pts twice if we are
# already fetching it. However this does not happen in practice,
# which makes sense, because different updates have different pts.
if self._state_cache.update(update, check_only=True):
# If the update doesn't have pts, fetching won't do anything.
# For example, UpdateUserStatus or UpdateChatUserTyping.
try:
await _get_difference(self, update, entities, channel_id, pts_date)
except OSError:
pass # We were disconnected, that's okay
except RpcError:
# There's a high chance the request fails because we lack
# the channel. Because these "happen sporadically" (#1428)
# we should be okay (no flood waits) even if more occur.
pass
except ValueError:
# There is a chance that GetFullChannel and GetDifference
# inside the _get_difference() function will end up with
# ValueError("Request was unsuccessful N time(s)") for whatever reasons.
pass
built = EventBuilderDict(self, update, entities, others) built = EventBuilderDict(self, update, entities, others)
for builder, callback in self._event_builders: for builder, callback in self._event_builders:

View File

@ -76,9 +76,6 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
exceptions.append(e) exceptions.append(e)
results.append(None) results.append(None)
continue continue
entities = self._entity_cache.add(result)
if entities:
await self.session.insert_entities(entities)
exceptions.append(None) exceptions.append(None)
results.append(result) results.append(result)
request_index += 1 request_index += 1
@ -88,9 +85,6 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
return results return results
else: else:
result = await future result = await future
entities = self._entity_cache.add(result)
if entities:
await self.session.insert_entities(entities)
return result return result
except ServerError as e: except ServerError as e:
last_error = e last_error = e
@ -129,10 +123,7 @@ async def _call(self: 'TelegramClient', sender, request, ordered=False, flood_sl
raise raise
await self._switch_dc(e.new_dc) await self._switch_dc(e.new_dc)
if self._raise_last_call_error and last_error is not None: raise last_error
raise last_error
raise ValueError('Request was unsuccessful {} time(s)'
.format(attempt))
async def get_me(self: 'TelegramClient', input_peer: bool = False) \ async def get_me(self: 'TelegramClient', input_peer: bool = False) \
@ -147,15 +138,12 @@ async def is_bot(self: 'TelegramClient') -> bool:
return self._session_state.bot if self._session_state else False return self._session_state.bot if self._session_state else False
async def is_user_authorized(self: 'TelegramClient') -> bool: async def is_user_authorized(self: 'TelegramClient') -> bool:
if self._authorized is None: try:
try: # Any request that requires authorization will work
# Any request that requires authorization will work await self(_tl.fn.updates.GetState())
await self(_tl.fn.updates.GetState()) return True
self._authorized = True except RpcError:
except RpcError: return False
self._authorized = False
return self._authorized
async def get_entity( async def get_entity(
self: 'TelegramClient', self: 'TelegramClient',
@ -236,14 +224,6 @@ async def get_input_entity(
except TypeError: except TypeError:
pass pass
# Next in priority is having a peer (or its ID) cached in-memory
try:
# 0x2d45687 == crc32(b'Peer')
if isinstance(peer, int) or peer.SUBCLASS_OF_ID == 0x2d45687:
return self._entity_cache[peer]
except (AttributeError, KeyError):
pass
# Then come known strings that take precedence # Then come known strings that take precedence
if peer in ('me', 'self'): if peer in ('me', 'self'):
return _tl.InputPeerSelf() return _tl.InputPeerSelf()
@ -254,7 +234,7 @@ async def get_input_entity(
except TypeError: except TypeError:
pass pass
else: else:
entity = await self.session.get_entity(None, peer_id) entity = await self._session.get_entity(None, peer_id)
if entity: if entity:
if entity.ty in (Entity.USER, Entity.BOT): if entity.ty in (Entity.USER, Entity.BOT):
return _tl.InputPeerUser(entity.id, entity.access_hash) return _tl.InputPeerUser(entity.id, entity.access_hash)

View File

@ -165,8 +165,7 @@ class Album(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(self.sender_id, self._entities)
self.sender_id, self._entities, client._entity_cache)
self.messages = [ self.messages = [
_custom.Message._new(client, m, self._entities, None) _custom.Message._new(client, m, self._entities, None)

View File

@ -166,8 +166,7 @@ class CallbackQuery(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(self.sender_id, self._entities)
self.sender_id, self._entities, client._entity_cache)
@property @property
def id(self): def id(self):
@ -223,13 +222,10 @@ class CallbackQuery(EventBuilder):
self._input_sender = utils.get_input_peer(self._chat) self._input_sender = utils.get_input_peer(self._chat)
if not getattr(self._input_sender, 'access_hash', True): if not getattr(self._input_sender, 'access_hash', True):
# getattr with True to handle the InputPeerSelf() case # getattr with True to handle the InputPeerSelf() case
try: m = await self.get_message()
self._input_sender = self._client._entity_cache[self._sender_id] if m:
except KeyError: self._sender = m._sender
m = await self.get_message() self._input_sender = m._input_sender
if m:
self._sender = m._sender
self._input_sender = m._input_sender
async def answer( async def answer(
self, message=None, cache_time=0, *, url=None, alert=False): self, message=None, cache_time=0, *, url=None, alert=False):

View File

@ -401,20 +401,13 @@ class ChatAction(EventBuilder):
if self._input_users is None and self._user_ids: if self._input_users is None and self._user_ids:
self._input_users = [] self._input_users = []
for user_id in self._user_ids: for user_id in self._user_ids:
# First try to get it from our entities # Try to get it from our entities
try: try:
self._input_users.append(utils.get_input_peer(self._entities[user_id])) self._input_users.append(utils.get_input_peer(self._entities[user_id]))
continue continue
except (KeyError, TypeError): except (KeyError, TypeError):
pass pass
# If missing, try from the entity cache
try:
self._input_users.append(self._client._entity_cache[user_id])
continue
except KeyError:
pass
return self._input_users or [] return self._input_users or []
async def get_input_users(self): async def get_input_users(self):

View File

@ -147,8 +147,7 @@ class EventCommon(ChatGetter, abc.ABC):
# TODO Nuke # TODO Nuke
self._client = client self._client = client
if self._chat_peer: if self._chat_peer:
self._chat, self._input_chat = utils._get_entity_pair( self._chat, self._input_chat = utils._get_entity_pair(self.chat_id, self._entities)
self.chat_id, self._entities, client._entity_cache)
else: else:
self._chat = self._input_chat = None self._chat = self._input_chat = None

View File

@ -98,8 +98,7 @@ class InlineQuery(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(self.sender_id, self._entities)
self.sender_id, self._entities, client._entity_cache)
@property @property
def id(self): def id(self):

View File

@ -94,8 +94,7 @@ class UserUpdate(EventBuilder):
def _set_client(self, client): def _set_client(self, client):
super()._set_client(client) super()._set_client(client)
self._sender, self._input_sender = utils._get_entity_pair( self._sender, self._input_sender = utils._get_entity_pair(self.sender_id, self._entities)
self.sender_id, self._entities, client._entity_cache)
@property @property
def user(self): def user(self):

View File

@ -578,20 +578,16 @@ def get_input_group_call(call):
_raise_cast_fail(call, 'InputGroupCall') _raise_cast_fail(call, 'InputGroupCall')
def _get_entity_pair(entity_id, entities, cache, def _get_entity_pair(entity_id, entities,
get_input_peer=get_input_peer): get_input_peer=get_input_peer):
""" """
Returns ``(entity, input_entity)`` for the given entity ID. Returns ``(entity, input_entity)`` for the given entity ID.
""" """
entity = entities.get(entity_id) entity = entities.get(entity_id)
try: try:
input_entity = cache[entity_id] input_entity = get_input_peer(entity)
except KeyError: except TypeError:
# KeyError is unlikely, so another TypeError won't hurt input_entity = None
try:
input_entity = get_input_peer(entity)
except TypeError:
input_entity = None
return entity, input_entity return entity, input_entity

View File

@ -64,12 +64,6 @@ class ChatGetter(abc.ABC):
Note that this might not be available if the library doesn't Note that this might not be available if the library doesn't
have enough information available. have enough information available.
""" """
if self._input_chat is None and self._chat_peer and self._client:
try:
self._input_chat = self._client._entity_cache[self._chat_peer]
except KeyError:
pass
return self._input_chat return self._input_chat
async def get_input_chat(self): async def get_input_chat(self):

View File

@ -49,12 +49,6 @@ class Draft:
""" """
Input version of the entity. Input version of the entity.
""" """
if not self._input_entity:
try:
self._input_entity = self._client._entity_cache[self._peer]
except KeyError:
pass
return self._input_entity return self._input_entity
async def get_entity(self): async def get_entity(self):

View File

@ -35,13 +35,11 @@ class Forward(ChatGetter, SenderGetter):
ty = helpers._entity_type(original.from_id) ty = helpers._entity_type(original.from_id)
if ty == helpers._EntityType.USER: if ty == helpers._EntityType.USER:
sender_id = utils.get_peer_id(original.from_id) sender_id = utils.get_peer_id(original.from_id)
sender, input_sender = utils._get_entity_pair( sender, input_sender = utils._get_entity_pair(sender_id, entities)
sender_id, entities, client._entity_cache)
elif ty in (helpers._EntityType.CHAT, helpers._EntityType.CHANNEL): elif ty in (helpers._EntityType.CHAT, helpers._EntityType.CHANNEL):
peer = original.from_id peer = original.from_id
chat, input_chat = utils._get_entity_pair( chat, input_chat = utils._get_entity_pair(utils.get_peer_id(peer), entities)
utils.get_peer_id(peer), entities, client._entity_cache)
# This call resets the client # This call resets the client
ChatGetter.__init__(self, peer, chat=chat, input_chat=input_chat) ChatGetter.__init__(self, peer, chat=chat, input_chat=input_chat)

View File

@ -435,20 +435,15 @@ class Message(ChatGetter, SenderGetter):
if self.peer_id == _tl.PeerUser(client._session_state.user_id) and not self.fwd_from: if self.peer_id == _tl.PeerUser(client._session_state.user_id) and not self.fwd_from:
self.out = True self.out = True
cache = client._entity_cache self._sender, self._input_sender = utils._get_entity_pair(self.sender_id, entities)
self._sender, self._input_sender = utils._get_entity_pair( self._chat, self._input_chat = utils._get_entity_pair(self.chat_id, entities)
self.sender_id, entities, cache)
self._chat, self._input_chat = utils._get_entity_pair(
self.chat_id, entities, cache)
if input_chat: # This has priority if input_chat: # This has priority
self._input_chat = input_chat self._input_chat = input_chat
if self.via_bot_id: if self.via_bot_id:
self._via_bot, self._via_input_bot = utils._get_entity_pair( self._via_bot, self._via_input_bot = utils._get_entity_pair(self.via_bot_id, entities)
self.via_bot_id, entities, cache)
if self.fwd_from: if self.fwd_from:
self._forward = Forward(self._client, self.fwd_from, entities) self._forward = Forward(self._client, self.fwd_from, entities)
@ -1339,10 +1334,7 @@ class Message(ChatGetter, SenderGetter):
raise ValueError('No input sender') raise ValueError('No input sender')
return bot return bot
else: else:
try: raise ValueError('No input sender') from None
return self._client._entity_cache[self.via_bot_id]
except KeyError:
raise ValueError('No input sender') from None
def _document_by_attribute(self, kind, condition=None): def _document_by_attribute(self, kind, condition=None):
""" """

View File

@ -64,12 +64,6 @@ class SenderGetter(abc.ABC):
Note that this might not be available if the library can't Note that this might not be available if the library can't
find the input chat, or if the message a broadcast on a channel. find the input chat, or if the message a broadcast on a channel.
""" """
if self._input_sender is None and self._sender_id and self._client:
try:
self._input_sender = \
self._client._entity_cache[self._sender_id]
except KeyError:
pass
return self._input_sender return self._input_sender
async def get_input_sender(self): async def get_input_sender(self):