From d87b68a75621dc3be06a0933e9b5e8d73a43d460 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Mon, 7 Feb 2022 09:28:39 +0100 Subject: [PATCH] Fix direct mutation of objects in friendly methods --- telethon/_client/chats.py | 56 +++++++++++++++----------------- telethon/_client/dialogs.py | 14 +++++--- telethon/_client/downloads.py | 15 +++++---- telethon/_client/messageparse.py | 14 ++++---- telethon/_client/messages.py | 21 ++++++------ telethon/_client/users.py | 6 ++-- 6 files changed, 65 insertions(+), 61 deletions(-) diff --git a/telethon/_client/chats.py b/telethon/_client/chats.py index e8a95291..a9a4a45f 100644 --- a/telethon/_client/chats.py +++ b/telethon/_client/chats.py @@ -3,6 +3,7 @@ import inspect import itertools import string import typing +import dataclasses from .. import errors, _tl from .._misc import helpers, utils, requestiter, tlobject, enums, hints @@ -19,11 +20,9 @@ _MAX_PROFILE_PHOTO_CHUNK_SIZE = 100 class _ChatAction: def __init__(self, client, chat, action, *, delay, auto_cancel): self._client = client - self._chat = chat - self._action = action self._delay = delay self._auto_cancel = auto_cancel - self._request = None + self._request = _tl.fn.messages.SetTyping(chat, action) self._task = None self._running = False @@ -31,14 +30,7 @@ class _ChatAction: return self._once().__await__() async def __aenter__(self): - self._chat = await self._client.get_input_entity(self._chat) - - # Since `self._action` is passed by reference we can avoid - # recreating the request all the time and still modify - # `self._action.progress` directly in `progress`. - self._request = _tl.fn.messages.SetTyping( - self._chat, self._action) - + self._request = dataclasses.replace(self._request, peer=await self._client.get_input_entity(self._request.peer)) self._running = True self._task = asyncio.create_task(self._update()) return self @@ -55,7 +47,7 @@ class _ChatAction: self._task = None async def _once(self): - self._chat = await self._client.get_input_entity(self._chat) + self._request = dataclasses.replace(self._request, peer=await self._client.get_input_entity(self._request.peer)) await self._client(_tl.fn.messages.SetTyping(self._chat, self._action)) async def _update(self): @@ -93,8 +85,11 @@ class _ChatAction: }[enums.Action(action)] def progress(self, current, total): - if hasattr(self._action, 'progress'): - self._action.progress = 100 * round(current / total) + if hasattr(self._request.action, 'progress'): + self._request = dataclasses.replace( + self._request, + action=dataclasses.replace(self._request.action, progress=100 * round(current / total)) + ) class _ParticipantsIter(requestiter.RequestIter): @@ -190,8 +185,8 @@ class _ParticipantsIter(requestiter.RequestIter): # Most people won't care about getting exactly 12,345 # members so it doesn't really matter not to be 100% # precise with being out of the offset/limit here. - self.request.limit = min( - self.limit - self.request.offset, _MAX_PARTICIPANTS_CHUNK_SIZE) + self.request = dataclasses.replace(self.request, limit=min( + self.limit - self.request.offset, _MAX_PARTICIPANTS_CHUNK_SIZE)) if self.request.offset > self.limit: return True @@ -199,7 +194,7 @@ class _ParticipantsIter(requestiter.RequestIter): participants = await self.client(self.request) self.total = participants.count - self.request.offset += len(participants.participants) + self.request = dataclasses.replace(self.request, offset=self.request.offset + len(participants.participants)) users = {user.id: user for user in participants.users} for participant in participants.participants: if isinstance(participant, _tl.ChannelParticipantBanned): @@ -253,20 +248,20 @@ class _AdminLogIter(requestiter.RequestIter): ) async def _load_next_chunk(self): - self.request.limit = min(self.left, _MAX_ADMIN_LOG_CHUNK_SIZE) + self.request = dataclasses.replace(self.request, limit=min(self.left, _MAX_ADMIN_LOG_CHUNK_SIZE)) r = await self.client(self.request) entities = {utils.get_peer_id(x): x for x in itertools.chain(r.users, r.chats)} - self.request.max_id = min((e.id for e in r.events), default=0) + self.request = dataclasses.replace(self.request, max_id=min((e.id for e in r.events), default=0)) for ev in r.events: if isinstance(ev.action, _tl.ChannelAdminLogEventActionEditMessage): - ev.action.prev_message = _custom.Message._new( - self.client, ev.action.prev_message, entities, self.entity) - - ev.action.new_message = _custom.Message._new( - self.client, ev.action.new_message, entities, self.entity) + ev = dataclasses.replace(ev, action=dataclasses.replace( + ev.action, + prev_message=_custom.Message._new(self.client, ev.action.prev_message, entities, self.entity), + new_message=_custom.Message._new(self.client, ev.action.new_message, entities, self.entity) + )) elif isinstance(ev.action, _tl.ChannelAdminLogEventActionDeleteMessage): @@ -308,7 +303,7 @@ class _ProfilePhotoIter(requestiter.RequestIter): ) if self.limit == 0: - self.request.limit = 1 + self.request = dataclasses.replace(self.request, limit=1) result = await self.client(self.request) if isinstance(result, _tl.photos.Photos): self.total = len(result.photos) @@ -319,7 +314,7 @@ class _ProfilePhotoIter(requestiter.RequestIter): self.total = getattr(result, 'count', None) async def _load_next_chunk(self): - self.request.limit = min(self.left, _MAX_PROFILE_PHOTO_CHUNK_SIZE) + self.request = dataclasses.replace(self.request, limit=min(self.left, _MAX_PROFILE_PHOTO_CHUNK_SIZE)) result = await self.client(self.request) if isinstance(result, _tl.photos.Photos): @@ -338,7 +333,7 @@ class _ProfilePhotoIter(requestiter.RequestIter): if len(self.buffer) < self.request.limit: self.left = len(self.buffer) else: - self.request.offset += len(result.photos) + self.request = dataclasses.replace(self.request, offset=self.request.offset + len(result.photos)) else: # Some broadcast channels have a photo that this request doesn't # retrieve for whatever random reason the Telegram server feels. @@ -368,8 +363,11 @@ class _ProfilePhotoIter(requestiter.RequestIter): if len(result.messages) < self.request.limit: self.left = len(self.buffer) elif result.messages: - self.request.add_offset = 0 - self.request.offset_id = result.messages[-1].id + self.request = dataclasses.replace( + self.request, + add_offset=0, + offset_id=result.messages[-1].id + ) def get_participants( diff --git a/telethon/_client/dialogs.py b/telethon/_client/dialogs.py index e3832ee8..e0f32547 100644 --- a/telethon/_client/dialogs.py +++ b/telethon/_client/dialogs.py @@ -2,6 +2,7 @@ import asyncio import inspect import itertools import typing +import dataclasses from .. import errors, _tl from .._misc import helpers, utils, requestiter, hints @@ -49,7 +50,7 @@ class _DialogsIter(requestiter.RequestIter): self.ignore_migrated = ignore_migrated async def _load_next_chunk(self): - self.request.limit = min(self.left, _MAX_CHUNK_SIZE) + self.request = dataclasses.replace(self.request, limit=min(self.left, _MAX_CHUNK_SIZE)) r = await self.client(self.request) self.total = getattr(r, 'count', len(r.dialogs)) @@ -103,10 +104,13 @@ class _DialogsIter(requestiter.RequestIter): for d in reversed(r.dialogs) )), None) - self.request.exclude_pinned = True - self.request.offset_id = last_message.id if last_message else 0 - self.request.offset_date = last_message.date if last_message else None - self.request.offset_peer = self.buffer[-1].input_entity + self.request = dataclasses.replace( + self.request, + exclude_pinned=True, + offset_id=last_message.id if last_message else 0, + offset_date=last_message.date if last_message else None, + offset_peer=self.buffer[-1].input_entity, + ) class _DraftsIter(requestiter.RequestIter): diff --git a/telethon/_client/downloads.py b/telethon/_client/downloads.py index 5bf8f01f..67c4622d 100644 --- a/telethon/_client/downloads.py +++ b/telethon/_client/downloads.py @@ -5,6 +5,7 @@ import pathlib import typing import inspect import asyncio +import dataclasses from .._crypto import AES from .._misc import utils, helpers, requestiter, tlobject, hints, enums @@ -55,7 +56,7 @@ class _DirectDownloadIter(requestiter.RequestIter): self.left = len(self.buffer) await self.close() else: - self.request.offset += self._stride + self.request = dataclasses.replace(self.request, offset=self.request.offset + self._stride) async def _request(self): try: @@ -102,7 +103,7 @@ class _DirectDownloadIter(requestiter.RequestIter): if document.id != self.request.location.id: raise - self.request.location.file_reference = document.file_reference + self.request.location = dataclasses.replace(self.request.location, file_reference=document.file_reference) return await self._request() async def close(self): @@ -134,18 +135,18 @@ class _GenericDownloadIter(_DirectDownloadIter): before = self.request.offset # 1.2. We have to fetch from a valid offset, so remove that bad part - self.request.offset -= bad + self.request = dataclasses.replace(self.request, offset=self.request.offset - bad) done = False while not done and len(data) - bad < self._chunk_size: cur = await self._request() - self.request.offset += self.request.limit + self.request = dataclasses.replace(self.request, offset=self.request.offset - self.request.limit) data += cur done = len(cur) < self.request.limit # 1.3 Restore our last desired offset - self.request.offset = before + self.request = dataclasses.replace(self.request, offset=before) # 2. Fill the buffer with the data we have # 2.1. Slicing `bytes` is expensive, yield `memoryview` instead @@ -157,7 +158,7 @@ class _GenericDownloadIter(_DirectDownloadIter): self.buffer.append(mem[i:i + self._chunk_size]) # 2.3. We will yield this offset, so move to the next one - self.request.offset += self._stride + self.request = dataclasses.replace(self.request, offset=self.request.offset + self._stride) # 2.4. If we are in the last chunk, we will return the last partial data if done: @@ -172,7 +173,7 @@ class _GenericDownloadIter(_DirectDownloadIter): # 3. Be careful with the offsets. Re-fetching a bit of data # is fine, since it greatly simplifies things. # TODO Try to not re-fetch data - self.request.offset -= self._stride + self.request = dataclasses.replace(self.request, offset=self.request.offset - self._stride) async def download_profile_photo( diff --git a/telethon/_client/messageparse.py b/telethon/_client/messageparse.py index 1366bca3..ab828209 100644 --- a/telethon/_client/messageparse.py +++ b/telethon/_client/messageparse.py @@ -87,7 +87,7 @@ def _get_response_message(self: 'TelegramClient', request, result, input_chat): elif isinstance(update, ( _tl.UpdateNewChannelMessage, _tl.UpdateNewMessage)): - update.message = _custom.Message._new(self, update.message, entities, input_chat) + message = _custom.Message._new(self, update.message, entities, input_chat) # Pinning a message with `updatePinnedMessage` seems to # always produce a service message we can't map so return @@ -97,20 +97,20 @@ def _get_response_message(self: 'TelegramClient', request, result, input_chat): # # TODO this method is getting messier and messier as time goes on if hasattr(request, 'random_id') or utils.is_list_like(request): - id_to_message[update.message.id] = update.message + id_to_message[message.id] = message else: - return update.message + return message elif (isinstance(update, _tl.UpdateEditMessage) and helpers._entity_type(request.peer) != helpers._EntityType.CHANNEL): - update.message = _custom.Message._new(self, update.message, entities, input_chat) + message = _custom.Message._new(self, update.message, entities, input_chat) # Live locations use `sendMedia` but Telegram responds with # `updateEditMessage`, which means we won't have `id` field. if hasattr(request, 'random_id'): - id_to_message[update.message.id] = update.message - elif request.id == update.message.id: - return update.message + id_to_message[message.id] = message + elif request.id == message.id: + return message elif (isinstance(update, _tl.UpdateEditChannelMessage) and utils.get_peer_id(request.peer) == diff --git a/telethon/_client/messages.py b/telethon/_client/messages.py index 1d85a5ee..e2065c7b 100644 --- a/telethon/_client/messages.py +++ b/telethon/_client/messages.py @@ -3,6 +3,7 @@ import itertools import time import typing import warnings +import dataclasses from .._misc import helpers, utils, requestiter, hints from ..types import _custom @@ -142,7 +143,7 @@ class _MessagesIter(requestiter.RequestIter): and offset_date and not search and not offset_id: async for m in self.client.get_messages( self.entity, 1, offset_date=offset_date): - self.request.offset_id = m.id + 1 + self.request = dataclasses.replace(self.request, offset_id=m.id + 1) else: self.request = _tl.fn.messages.GetHistory( peer=self.entity, @@ -178,10 +179,10 @@ class _MessagesIter(requestiter.RequestIter): self.last_id = 0 if self.reverse else float('inf') async def _load_next_chunk(self): - self.request.limit = min(self.left, _MAX_CHUNK_SIZE) + self.request = dataclasses.replace(self.request, limit=min(self.left, _MAX_CHUNK_SIZE)) if self.reverse and self.request.limit != _MAX_CHUNK_SIZE: # Remember that we need -limit when going in reverse - self.request.add_offset = self.add_offset - self.request.limit + self.request = dataclasses.replace(self.request, add_offset=self.add_offset - self.request.limit) r = await self.client(self.request) self.total = getattr(r, 'count', len(r.messages)) @@ -241,28 +242,28 @@ class _MessagesIter(requestiter.RequestIter): """ After making the request, update its offset with the last message. """ - self.request.offset_id = last_message.id + self.request = dataclasses.replace(self.request, offset_id=last_message.id) if self.reverse: # We want to skip the one we already have - self.request.offset_id += 1 + self.request = dataclasses.replace(self.request, offset_id=self.request.offset_id + 1) if isinstance(self.request, _tl.fn.messages.Search): # Unlike getHistory and searchGlobal that use *offset* date, # this is *max* date. This means that doing a search in reverse # will break it. Since it's not really needed once we're going # (only for the first request), it's safe to just clear it off. - self.request.max_date = None + self.request = dataclasses.replace(self.request, max_date=None) else: # getHistory, searchGlobal and getReplies call it offset_date - self.request.offset_date = last_message.date + self.request = dataclasses.replace(self.request, offset_date=last_message.date) if isinstance(self.request, _tl.fn.messages.SearchGlobal): if last_message.input_chat: - self.request.offset_peer = last_message.input_chat + self.request = dataclasses.replace(self.request, offset_peer=last_message.input_chat) else: - self.request.offset_peer = _tl.InputPeerEmpty() + self.request = dataclasses.replace(self.request, offset_peer=_tl.InputPeerEmpty()) - self.request.offset_rate = getattr(response, 'next_rate', 0) + self.request = dataclasses.replace(self.request, offset_rate=getattr(response, 'next_rate', 0)) class _IDsIter(requestiter.RequestIter): diff --git a/telethon/_client/users.py b/telethon/_client/users.py index 763a31c4..b44564e1 100644 --- a/telethon/_client/users.py +++ b/telethon/_client/users.py @@ -3,6 +3,7 @@ import datetime import itertools import time import typing +import dataclasses from ..errors._custom import MultiError from ..errors._rpcbase import RpcError, ServerError, FloodError, InvalidDcError, UnauthorizedError @@ -370,8 +371,7 @@ async def _get_input_dialog(self: 'TelegramClient', dialog): """ try: if dialog.SUBCLASS_OF_ID == 0xa21c9795: # crc32(b'InputDialogPeer') - dialog.peer = await self.get_input_entity(dialog.peer) - return dialog + return dataclasses.replace(dialog, peer=await self.get_input_entity(dialog.peer)) elif dialog.SUBCLASS_OF_ID == 0xc91c90b6: # crc32(b'InputPeer') return _tl.InputDialogPeer(dialog) except AttributeError: @@ -388,7 +388,7 @@ async def _get_input_notify(self: 'TelegramClient', notify): try: if notify.SUBCLASS_OF_ID == 0x58981615: if isinstance(notify, _tl.InputNotifyPeer): - notify.peer = await self.get_input_entity(notify.peer) + return dataclasses.replace(notify, peer=await self.get_input_entity(notify.peer)) return notify except AttributeError: pass