# Mirror of https://github.com/LonamiWebs/Telethon.git
# Synced 2024-11-29 21:03:45 +03:00 (431 lines, 14 KiB, Python).
"""Various helpers not related to the Telegram API itself"""
|
|
import asyncio
|
|
import io
|
|
import enum
|
|
import os
|
|
import struct
|
|
import inspect
|
|
import logging
|
|
import functools
|
|
from pathlib import Path
|
|
from hashlib import sha1
|
|
|
|
|
|
@enum.unique
class _EntityType(enum.Enum):
    """Kinds of entity that `_entity_type` can report."""

    USER = 0
    CHAT = 1
    CHANNEL = 2
|
|
|
|
|
|
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
|
|
|
|
|
|
# region Multiple utilities
|
|
|
|
|
|
def generate_random_long(signed=True):
    """
    Returns a random integer built from 8 bytes of ``os.urandom``.

    The bytes are read as little-endian, and interpreted as a signed
    value when `signed` is `True` (the default).
    """
    random_bytes = os.urandom(8)
    return int.from_bytes(random_bytes, byteorder='little', signed=signed)
|
|
|
|
|
|
def ensure_parent_dir_exists(file_path):
    """
    Creates the directory that contains `file_path` if it is missing.

    Nothing is done when `file_path` has no directory component.
    """
    directory = os.path.dirname(file_path)
    if not directory:
        return

    # `exist_ok` makes this a no-op when the directory is already there.
    os.makedirs(directory, exist_ok=True)
|
|
|
|
|
|
def add_surrogate(text):
    """
    Returns `text` with every character in the ``0x10000..0x10FFFF`` range
    replaced by the two UTF-16 code units (surrogate pair) that encode it.

    All other characters are kept as they are.
    """
    pieces = []
    for char in text:
        if 0x10000 <= ord(char) <= 0x10FFFF:
            # SMP -> Surrogate Pairs (Telegram offsets are calculated with these).
            # See https://en.wikipedia.org/wiki/Plane_(Unicode)#Overview for more.
            high, low = struct.unpack('<HH', char.encode('utf-16le'))
            pieces.append(chr(high) + chr(low))
        else:
            pieces.append(char)

    return ''.join(pieces)
|
|
|
|
|
|
def del_surrogate(text):
    """
    Returns `text` after a UTF-16 round trip that lets surrogate code
    points through on encoding, so surrogate pairs decode back into the
    single characters they represent.
    """
    utf16_bytes = text.encode('utf-16', 'surrogatepass')
    return utf16_bytes.decode('utf-16')
|
|
|
|
|
|
def within_surrogate(text, index, *, length=None):
    """
    `True` if ``index`` is within a surrogate (before and after it, not at!).

    That is, `True` only when ``text[index - 1]`` is a high surrogate
    (``U+D800..U+DBFF``) and ``text[index]`` is a low surrogate
    (``U+DC00..U+DFFF``), which means that cutting `text` at ``index``
    would split a surrogate pair in two.

    `length` is accepted for backward compatibility only. The bounds are
    always checked against ``len(text)``, so a mismatched value can never
    cause an out-of-range access.
    """
    return (
        # In bounds. Index 1 is valid too: a pair may occupy positions 0-1.
        0 < index < len(text) and
        # Previous is the first (high) half of a pair...
        '\ud800' <= text[index - 1] <= '\udbff' and
        # ...and current is the second (low) half. The boundary between two
        # adjacent pairs (low followed by high) is *not* within a surrogate.
        '\udc00' <= text[index] <= '\udfff'
    )
|
|
|
|
|
|
def strip_text(text, entities):
    """
    Strips whitespace from the given text modifying the provided entities.

    This assumes that there are no overlapping entities, that their length
    is greater or equal to one, and that their length is not out of bounds.

    Every entity must have mutable ``offset`` and ``length`` attributes.
    The `entities` list is modified in place: entities are shrunk, shifted
    to the left, or removed when they would end up empty. The stripped
    text is returned.
    """
    if not entities:
        return text.strip()

    # Trailing whitespace: drop one character at a time, shrinking (or
    # removing) the last entity whenever it covers the dropped character.
    while text and text[-1].isspace():
        e = entities[-1]
        if e.offset + e.length == len(text):
            if e.length == 1:
                del entities[-1]
                if not entities:
                    return text.strip()
            else:
                e.length -= 1
        text = text[:-1]

    # Leading whitespace: drop one character at a time. Entities which
    # start later move one position to the left, and those starting at
    # the dropped character shrink (or are removed).
    while text and text[0].isspace():
        # Walking the indices backwards keeps them valid after a deletion.
        for i in reversed(range(len(entities))):
            e = entities[i]
            if e.offset != 0:
                e.offset -= 1
                continue

            if e.length == 1:
                # Remove the entity being inspected. Deleting index 0 is
                # wrong whenever the entity at offset 0 is not the first
                # one in the list.
                del entities[i]
                if not entities:
                    return text.lstrip()
            else:
                e.length -= 1

        text = text[1:]

    return text
|
|
|
|
|
|
def retry_range(retries, force_retry=True):
    """
    Yields consecutive attempt numbers, starting from 1.

    When `retries` is `None` or negative the sequence never ends.
    Otherwise it yields `retries` numbers, or `retries + 1` of them
    when `force_retry` is `True` (the default).
    """
    limit = retries

    # We need at least one iteration even if the retries are 0
    # when force_retry is True.
    if force_retry and limit is not None and not limit < 0:
        limit += 1

    current = 0
    while current != limit:
        current += 1
        yield current
|
|
|
|
|
|
|
|
async def _maybe_await(value):
    """
    Returns `value` as-is, unless it is awaitable, in which case it is
    awaited first and the result of doing so is returned instead.
    """
    if not inspect.isawaitable(value):
        return value

    return await value
|
|
|
|
|
|
async def _cancel(log, **tasks):
    """
    Helper to cancel one or more tasks gracefully, logging exceptions.

    Every keyword argument maps a name (only used in log messages) to
    the task to cancel. Falsy values are skipped. Each task is cancelled
    and then awaited, and unexpected errors are reported through
    ``log.exception`` rather than raised.
    """
    for label, pending in tasks.items():
        if pending:
            pending.cancel()
            try:
                await pending
            except (asyncio.CancelledError, RuntimeError):
                # Cancellation is the expected outcome. As for RuntimeError,
                # it is probably: RuntimeError: await wasn't used with future
                #
                # See: https://github.com/python/cpython/blob/12d3061c7819a73d891dcce44327410eaf0e1bc2/Lib/asyncio/futures.py#L265
                #
                # Happens with _asyncio.Task instances (in "Task cancelling" state)
                # trying to SIGINT the program right during initial connection, on
                # _recv_loop coroutine (but we're creating its task explicitly with
                # a loop, so how can it bug out like this?).
                #
                # Since we're aware of this error there's no point in logging it.
                # *May* be https://bugs.python.org/issue37172
                pass
            except AssertionError as error:
                # In Python 3.6, the above RuntimeError is an AssertionError
                # See https://github.com/python/cpython/blob/7df32f844efed33ca781a016017eab7050263b90/Lib/asyncio/futures.py#L328
                if error.args != ("yield from wasn't used with future",):
                    log.exception('Unhandled exception from %s after cancelling '
                                  '%s (%s)', label, type(pending), pending)
            except Exception:
                log.exception('Unhandled exception from %s after cancelling '
                              '%s (%s)', label, type(pending), pending)
|
|
|
|
|
|
def _entity_type(entity):
    """
    Returns the `_EntityType` matching `entity`, based on its class name.

    Raises `TypeError` when `entity` has no ``SUBCLASS_OF_ID`` attribute,
    when that ID is not one of the accepted ones, or when the class name
    contains none of the known words.
    """
    # This could be a `utils` method that just ran a few `isinstance` on
    # `utils.get_peer(...)`'s result. However, there are *a lot* of auto
    # casts going on, plenty of calls and temporary short-lived objects.
    #
    # So we just check if a string is in the class name.
    # Still, assert that it's the right type to not return false results.
    try:
        accepted = (
            0x2d45687,   # crc32(b'Peer')
            0xc91c90b6,  # crc32(b'InputPeer')
            0xe669bf46,  # crc32(b'InputUser')
            0x40f202fd,  # crc32(b'InputChannel')
            0x2da17977,  # crc32(b'User')
            0xc5af5d94,  # crc32(b'Chat')
            0x1f4661b9,  # crc32(b'UserFull')
            0xd49a2697,  # crc32(b'ChatFull')
        )
        if entity.SUBCLASS_OF_ID not in accepted:
            raise TypeError('{} does not have any entity type'.format(entity))
    except AttributeError:
        raise TypeError('{} is not a TLObject, cannot determine entity type'.format(entity))

    class_name = entity.__class__.__name__

    # Checked in order; the first word found in the class name wins.
    for word, kind in (
        ('User', _EntityType.USER),
        ('Chat', _EntityType.CHAT),
        ('Channel', _EntityType.CHANNEL),
        ('Self', _EntityType.USER),
    ):
        if word in class_name:
            return kind

    # 'Empty' in name or not found, we don't care, not a valid entity.
    raise TypeError('{} does not have any entity type'.format(entity))
|
|
|
|
|
|
def pretty_print(obj, indent=None, max_depth=float('inf')):
    """
    Returns a string representation of `obj`.

    Objects with a callable ``_to_dict`` (or ``to_dict``) are converted
    with it first. Dictionaries are shown as ``name(key=value, ...)``,
    where the name is taken from their ``'_'`` key (``dict`` if missing),
    `str` and `bytes` use `repr`, other iterables are shown between
    square brackets, and anything else uses `repr`.

    With ``indent=None`` (the default) the result is a single line.
    Otherwise `indent` is the current nesting level, and the result
    spans several lines indented with tabs.

    Anything nested deeper than `max_depth` levels is shown as ``...``.
    """
    max_depth -= 1
    if max_depth < 0:
        return '...'

    converter = getattr(obj, '_to_dict', None) or getattr(obj, 'to_dict', None)
    if callable(converter):
        obj = converter()

    if indent is None:
        # Single-line output.
        if isinstance(obj, dict):
            label = obj.get('_', 'dict')
            fields = ', '.join(
                '{}={}'.format(key, pretty_print(value, indent, max_depth))
                for key, value in obj.items() if key != '_'
            )
            return '{}({})'.format(label, fields)

        if isinstance(obj, (str, bytes)):
            return repr(obj)

        if hasattr(obj, '__iter__'):
            items = ', '.join(
                pretty_print(item, indent, max_depth) for item in obj
            )
            return '[{}]'.format(items)

        return repr(obj)

    # Multi-line output, gathered in pieces and joined at the end.
    parts = []

    if isinstance(obj, dict):
        parts += [obj.get('_', 'dict'), '(']
        if obj:
            parts.append('\n')
            indent += 1
            for key, value in obj.items():
                if key != '_':
                    parts += [
                        '\t' * indent,
                        key,
                        '=',
                        pretty_print(value, indent, max_depth),
                        ',\n',
                    ]
            parts.pop()  # last ',\n'
            indent -= 1
            parts += ['\n', '\t' * indent]
        parts.append(')')

    elif isinstance(obj, (str, bytes)):
        parts.append(repr(obj))

    elif hasattr(obj, '__iter__'):
        parts.append('[\n')
        indent += 1
        for item in obj:
            parts += [
                '\t' * indent,
                pretty_print(item, indent, max_depth),
                ',\n',
            ]
        indent -= 1
        parts += ['\t' * indent, ']']

    else:
        parts.append(repr(obj))

    return ''.join(parts)
|
|
|
|
|
|
# endregion
|
|
|
|
# region Cryptographic related utils
|
|
|
|
|
|
def generate_key_data_from_nonce(server_nonce, new_nonce):
    """
    Generates the key data corresponding to the given nonce.

    Both nonces are integers, serialized as signed little-endian values
    of 16 bytes (`server_nonce`) and 32 bytes (`new_nonce`). The result
    is a ``(key, iv)`` tuple of `bytes`, built from SHA-1 digests of
    three different concatenations of the serialized nonces.
    """
    server_bytes = server_nonce.to_bytes(16, 'little', signed=True)
    new_bytes = new_nonce.to_bytes(32, 'little', signed=True)

    digest_new_server = sha1(new_bytes + server_bytes).digest()
    digest_server_new = sha1(server_bytes + new_bytes).digest()
    digest_new_new = sha1(new_bytes + new_bytes).digest()

    key = digest_new_server + digest_server_new[:12]
    iv = b''.join((digest_server_new[12:20], digest_new_new, new_bytes[:4]))
    return key, iv
|
|
|
|
|
|
# endregion
|
|
|
|
# region Custom Classes
|
|
|
|
|
|
class TotalList(list):
    """
    A `list` that also carries a `total` attribute.

    The `total` may not match `len` of the list, since it represents
    the total amount of items *available* somewhere else, and not the
    amount of items *in this list*. It starts at 0.

    Examples:

        .. code-block:: python

            # Telethon returns these lists in some cases (for example,
            # only when a chunk is returned, but the "total" count
            # is available).
            result = await client.get_messages(chat, limit=10)

            print(result.total)  # large number
            print(len(result))  # 10
            print(result[0])  # latest message

            for x in result:  # show the 10 messages
                print(x.text)

    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.total = 0

    def __str__(self):
        items = ', '.join(map(str, self))
        return '[{}, total={}]'.format(items, self.total)

    def __repr__(self):
        items = ', '.join(map(repr, self))
        return '[{}, total={}]'.format(items, self.total)
|
|
|
|
|
|
class _FileStream(io.IOBase):
    """
    Proxy around things that represent a file and need to be used as streams
    which may or not need to be closed.

    This will handle `pathlib.Path`, `str` paths, in-memory `bytes`, and
    anything IO-like (including `aiofiles`).

    It also provides access to the name and file size (also necessary).

    The stream is only set up when entering the ``async with`` block,
    and it is only closed on exit if this class was the one creating it.
    """
    def __init__(self, file, *, file_size=None):
        # Paths are stored as the string of their absolute form.
        self._file = str(file.absolute()) if isinstance(file, Path) else file
        self._name = None
        self._size = file_size
        self._stream = None
        self._close_stream = None

    async def __aenter__(self):
        source = self._file

        if isinstance(source, str):
            # A path on disk: we open the file, so we must close it too.
            self._name = os.path.basename(source)
            self._size = os.path.getsize(source)
            self._stream = open(source, 'rb')
            self._close_stream = True
            return self

        if isinstance(source, bytes):
            # In-memory data, wrapped in a stream of our own.
            self._size = len(source)
            self._stream = io.BytesIO(source)
            self._close_stream = True
            return self

        if not callable(getattr(source, 'read', None)):
            raise TypeError('file description should have a `read` method')

        if self._size is not None:
            # The size is already known, so the given stream is used as-is.
            self._name = getattr(source, 'name', None)
            self._stream = source
            self._close_stream = False
            return self

        # NOTE(review): `_name` is never set from here on, so `name` stays
        # `None` when the size had to be computed - confirm this is intended.
        if callable(getattr(source, 'seekable', None)):
            can_seek = await _maybe_await(source.seekable())
        else:
            can_seek = False

        if can_seek:
            # Jump to the end to learn the size, and then go back.
            start = await _maybe_await(source.tell())
            await _maybe_await(source.seek(0, os.SEEK_END))
            self._size = await _maybe_await(source.tell())
            await _maybe_await(source.seek(start, os.SEEK_SET))
            self._stream = source
            self._close_stream = False
        else:
            _log.warning(
                'Could not determine file size beforehand so the entire '
                'file will be read in-memory')

            contents = await _maybe_await(source.read())
            self._size = len(contents)
            self._stream = io.BytesIO(contents)
            self._close_stream = True

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Streams which were given to us are left open.
        if not (self._close_stream and self._stream):
            return

        await _maybe_await(self._stream.close())

    @property
    def file_size(self):
        return self._size

    @property
    def name(self):
        return self._name

    # Every method below simply forwards the call to the wrapped stream.
    def read(self, *args, **kwargs):
        return self._stream.read(*args, **kwargs)

    def readinto(self, *args, **kwargs):
        return self._stream.readinto(*args, **kwargs)

    def write(self, *args, **kwargs):
        return self._stream.write(*args, **kwargs)

    def fileno(self, *args, **kwargs):
        return self._stream.fileno(*args, **kwargs)

    def flush(self, *args, **kwargs):
        return self._stream.flush(*args, **kwargs)

    def isatty(self, *args, **kwargs):
        return self._stream.isatty(*args, **kwargs)

    def readable(self, *args, **kwargs):
        return self._stream.readable(*args, **kwargs)

    def readline(self, *args, **kwargs):
        return self._stream.readline(*args, **kwargs)

    def readlines(self, *args, **kwargs):
        return self._stream.readlines(*args, **kwargs)

    def seek(self, *args, **kwargs):
        return self._stream.seek(*args, **kwargs)

    def seekable(self, *args, **kwargs):
        return self._stream.seekable(*args, **kwargs)

    def tell(self, *args, **kwargs):
        return self._stream.tell(*args, **kwargs)

    def truncate(self, *args, **kwargs):
        return self._stream.truncate(*args, **kwargs)

    def writable(self, *args, **kwargs):
        return self._stream.writable(*args, **kwargs)

    def writelines(self, *args, **kwargs):
        return self._stream.writelines(*args, **kwargs)

    # close is special because it will be called by __del__ but we do NOT
    # want to close the file unless we have to (we're just a wrapper).
    # Instead, we do nothing (we should be used through the decorator which
    # has its own mechanism to close the file correctly).
    def close(self, *args, **kwargs):
        pass
|
|
|
|
# endregion
|