Merge branch 'master' into asyncio

This commit is contained in:
Lonami Exo 2017-10-16 10:03:01 +02:00
commit 917665852d
17 changed files with 450 additions and 245 deletions

View File

@ -129,13 +129,10 @@ class BinaryReader:
return False
# If there was still no luck, give up
self.seek(-4) # Go back
raise TypeNotFoundError(constructor_id)
# Create an empty instance of the class and
# fill it with the read attributes
result = clazz.empty()
result.on_response(self)
return result
return clazz.from_reader(self)
def tgread_vector(self):
"""Reads a vector (a list) of Telegram objects"""

View File

@ -15,7 +15,7 @@ class TcpClient:
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
elif isinstance(timeout, int) or isinstance(timeout, float):
elif isinstance(timeout, (int, float)):
self.timeout = float(timeout)
else:
raise ValueError('Invalid timeout type', type(timeout))

View File

@ -124,7 +124,6 @@ async def _do_authentication(connection):
raise AssertionError(server_dh_inner)
if server_dh_inner.nonce != res_pq.nonce:
print(server_dh_inner.nonce, res_pq.nonce)
raise SecurityError('Invalid nonce in encrypted answer')
if server_dh_inner.server_nonce != res_pq.server_nonce:

View File

@ -143,28 +143,25 @@ class Connection:
# TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed
# by asyncio?
packet_length_bytes = await self.read(4)
packet_length = int.from_bytes(packet_length_bytes, 'little')
packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
seq_bytes = await self.read(4)
seq = int.from_bytes(seq_bytes, 'little')
body = await self.read(packet_len - 12)
checksum = struct.unpack('<I', await self.read(4))[0]
body = await self.read(packet_length - 12)
checksum = int.from_bytes(await self.read(4), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
valid_checksum = crc32(packet_len_seq + body)
if checksum != valid_checksum:
raise InvalidChecksumError(checksum, valid_checksum)
return body
async def _recv_intermediate(self):
return await self.read(int.from_bytes(self.read(4), 'little'))
return await self.read(struct.unpack('<i', await self.read(4))[0])
async def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little')
length = struct.unpack('<B', await self.read(1))[0]
if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little')
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
return await self.read(length << 2)

View File

@ -11,7 +11,10 @@ from ..errors import (
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
)
from ..tl.functions.auth import LogOutRequest
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -180,24 +183,33 @@ class MtProtoSender:
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
return await self._handle_rpc_result(msg_id, sequence, reader)
if code == 0x347773c5: # pong
if code == Pong.CONSTRUCTOR_ID:
return await self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container
if code == MessageContainer.CONSTRUCTOR_ID:
return await self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed
if code == GzipPacked.CONSTRUCTOR_ID:
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt
if code == BadServerSalt.CONSTRUCTOR_ID:
return await self._handle_bad_server_salt(msg_id, sequence, reader)
if code == 0xa7eff811: # bad_msg_notification
if code == BadMsgNotification.CONSTRUCTOR_ID:
return await self._handle_bad_msg_notification(msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted
if code == 0x62d6b459:
if code == MsgDetailedInfo.CONSTRUCTOR_ID:
return await self._handle_msg_detailed_info(msg_id, sequence, reader)
if code == MsgNewDetailedInfo.CONSTRUCTOR_ID:
return await self._handle_msg_new_detailed_info(msg_id, sequence, reader)
if code == NewSessionCreated.CONSTRUCTOR_ID:
return await self._handle_new_session_created(msg_id, sequence, reader)
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object()
assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
@ -219,7 +231,12 @@ class MtProtoSender:
return True
self._logger.debug('Unknown message: {}'.format(hex(code)))
self._logger.debug(
'[WARN] Unknown message: {}, data left in the buffer: {}'
.format(
hex(code), repr(reader.get_bytes()[reader.tell_position():])
)
)
return False
# endregion
@ -239,7 +256,7 @@ class MtProtoSender:
the given type, or returns None if it's not found/doesn't match.
"""
message = self._pending_receive.get(msg_id, None)
if isinstance(message.request, t):
if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request
def _clear_all_pending(self):
@ -249,12 +266,13 @@ class MtProtoSender:
async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong')
reader.read_int(signed=False) # code
received_msg_id = reader.read_long()
pong = reader.tgread_object()
assert isinstance(pong, Pong)
request = self._pop_request(received_msg_id)
request = self._pop_request(pong.msg_id)
if request:
self._logger.debug('Pong confirmed a request')
request.result = pong
request.confirm_received.set()
return True
@ -278,14 +296,15 @@ class MtProtoSender:
async def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code
bad_msg_id = reader.read_long()
reader.read_int() # bad_msg_seq_no
reader.read_int() # error_code
new_salt = reader.read_long(signed=False)
self.session.salt = new_salt
bad_salt = reader.tgread_object()
assert isinstance(bad_salt, BadServerSalt)
request = self._pop_request(bad_msg_id)
# Our salt is unsigned, but the objects work with signed salts
self.session.salt = struct.unpack(
'<Q', struct.pack('<q', bad_salt.new_server_salt)
)[0]
request = self._pop_request(bad_salt.bad_msg_id)
if request:
await self.send(request)
@ -293,31 +312,53 @@ class MtProtoSender:
async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
self._logger.debug('Handling bad message notification')
reader.read_int(signed=False) # code
reader.read_long() # request_id
reader.read_int() # request_sequence
bad_msg = reader.tgread_object()
assert isinstance(bad_msg, BadMsgNotification)
error_code = reader.read_int()
error = BadMessageError(error_code)
if error_code in (16, 17):
error = BadMessageError(bad_msg.error_code)
if bad_msg.error_code in (16, 17):
# sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.')
return True
elif error_code == 32:
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64
return True
elif error_code == 33:
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16
return True
else:
raise error
async def _handle_msg_detailed_info(self, msg_id, sequence, reader):
msg_new = reader.tgread_object()
assert isinstance(msg_new, MsgDetailedInfo)
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
msg_new = reader.tgread_object()
assert isinstance(msg_new, MsgNewDetailedInfo)
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_new_session_created(self, msg_id, sequence, reader):
new_session = reader.tgread_object()
assert isinstance(new_session, NewSessionCreated)
# TODO https://goo.gl/LMyN7A
return True
async def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code
@ -346,25 +387,26 @@ class MtProtoSender:
# else TODO Where should this error be reported?
# Read may be async. Can an error not-belong to a request?
self._logger.debug('Read RPC error: %s', str(error))
else:
if request:
self._logger.debug('Reading request response')
if inner_code == 0x3072cfa1: # GZip packed
unpacked_data = gzip.decompress(reader.tgread_bytes())
with BinaryReader(unpacked_data) as compressed_reader:
request.on_response(compressed_reader)
else:
reader.seek(-4)
request.on_response(reader)
return True # All contents were read okay
self.session.process_entities(request.result)
request.confirm_received.set()
return True
elif request:
self._logger.debug('Reading request response')
if inner_code == 0x3072cfa1: # GZip packed
unpacked_data = gzip.decompress(reader.tgread_bytes())
with BinaryReader(unpacked_data) as compressed_reader:
request.on_response(compressed_reader)
else:
# If it's really a result for RPC from previous connection
# session, it will be skipped by the handle_container()
self._logger.debug('Lost request will be skipped.')
return False
reader.seek(-4)
request.on_response(reader)
self.session.process_entities(request.result)
request.confirm_received.set()
return True
# If it's really a result for RPC from previous connection
# session, it will be skipped by the handle_container()
self._logger.debug('Lost request will be skipped.')
return False
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data')

View File

@ -1,5 +1,6 @@
import logging
import os
import warnings
from datetime import timedelta, datetime
from hashlib import md5
from io import BytesIO
@ -55,7 +56,7 @@ class TelegramBareClient:
"""
# Current TelegramClient version
__version__ = '0.15.1'
__version__ = '0.15.2'
# TODO Make this thread-safe, all connections share the same DC
_dc_options = None
@ -386,7 +387,7 @@ class TelegramBareClient:
try:
for _ in range(retries):
result = await self._invoke(sender, *requests)
if result:
if result is not None:
return result
raise ValueError('Number of retries reached 0.')
@ -412,7 +413,7 @@ class TelegramBareClient:
pass # We will just retry
except ConnectionResetError:
if not self._authorized:
if not self._user_connected:
# Only attempt reconnecting if we're authorized
raise
@ -459,10 +460,14 @@ class TelegramBareClient:
'[ERROR] Telegram is having some internal issues', e
)
except FloodWaitError:
sender.disconnect()
self.disconnect()
raise
except FloodWaitError as e:
if e.seconds > self.session.flood_sleep_threshold | 0:
raise
self._logger.debug(
'Sleep of %d seconds below threshold, sleeping' % e.seconds
)
sleep(e.seconds)
# Some really basic functionality
@ -609,10 +614,8 @@ class TelegramBareClient:
cdn_decrypter = None
try:
offset_index = 0
offset = 0
while True:
offset = offset_index * part_size
try:
if cdn_decrypter:
result = await cdn_decrypter.get_file()
@ -633,7 +636,7 @@ class TelegramBareClient:
client = await self._get_exported_client(e.new_dc)
continue
offset_index += 1
offset += part_size
# If we have received no data (0 bytes), the file is over
# So there is nothing left to download and write
@ -670,6 +673,9 @@ class TelegramBareClient:
def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates"""
if not self.updates.get_workers:
warnings.warn("There are no update workers running, so adding an update handler will have no effect.")
sync = not self.updates.handlers
self.updates.handlers.append(handler)
if sync:

View File

@ -15,6 +15,7 @@ from .errors import (
)
from .network import ConnectionMode
from .tl import TLObject
from .tl.custom import Draft
from .tl.entity_database import EntityDatabase
from .tl.functions.account import (
GetPasswordRequest
@ -28,8 +29,8 @@ from .tl.functions.contacts import (
)
from .tl.functions.messages import (
GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest,
SendMessageRequest, GetChatsRequest
)
SendMessageRequest, GetChatsRequest,
GetAllDraftsRequest)
from .tl.functions import channels
from .tl.functions import messages
@ -46,7 +47,7 @@ from .tl.types import (
InputMediaUploadedDocument, InputMediaUploadedPhoto, InputPeerEmpty,
Message, MessageMediaContact, MessageMediaDocument, MessageMediaPhoto,
InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID,
UpdateNewMessage, UpdateShortSentMessage,
UpdateNewChannelMessage, UpdateNewMessage, UpdateShortSentMessage,
PeerUser, InputPeerUser, InputPeerChat, InputPeerChannel)
from .tl.types.messages import DialogsSlice
@ -227,7 +228,7 @@ class TelegramClient(TelegramBareClient):
if limit is None:
limit = float('inf')
dialogs = {} # Use Dialog.top_message as identifier to avoid dupes
dialogs = {} # Use peer id as identifier to avoid dupes
messages = {} # Used later for sorting TODO also return these?
entities = {}
while len(dialogs) < limit:
@ -242,7 +243,7 @@ class TelegramClient(TelegramBareClient):
break
for d in r.dialogs:
dialogs[d.top_message] = d
dialogs[utils.get_peer_id(d.peer, True)] = d
for m in r.messages:
messages[m.id] = m
@ -277,9 +278,20 @@ class TelegramClient(TelegramBareClient):
[utils.find_user_or_chat(d.peer, entities, entities) for d in ds]
)
# endregion
async def get_drafts(self): # TODO: Ability to provide a `filter`
"""
Gets all open draft messages.
# region Message requests
Returns a list of custom `Draft` objects that are easy to work with: You can call
`draft.set_message('text')` to change the message, or delete it through `draft.delete()`.
:return List[telethon.tl.custom.Draft]: A list of open drafts
"""
response = await self(GetAllDraftsRequest())
self.session.process_entities(response)
self.session.generate_sequence(response.seq)
drafts = [Draft._from_update(self, u) for u in response.updates]
return drafts
async def send_message(self,
entity,
@ -322,7 +334,7 @@ class TelegramClient(TelegramBareClient):
break
for update in result.updates:
if isinstance(update, UpdateNewMessage):
if isinstance(update, (UpdateNewChannelMessage, UpdateNewMessage)):
if update.message.id == msg_id:
return update.message
@ -463,9 +475,13 @@ class TelegramClient(TelegramBareClient):
async def send_file(self, entity, file, caption='',
force_document=False, progress_callback=None,
reply_to=None,
attributes=None,
**kwargs):
"""Sends a file to the specified entity.
The file may either be a path, a byte array, or a stream.
Note that if a byte array or a stream is given, a filename
or its type won't be inferred, and it will be sent as an
"unnamed application/octet-stream".
An optional caption can also be specified for said file.
@ -482,6 +498,10 @@ class TelegramClient(TelegramBareClient):
The "reply_to" parameter works exactly as the one on .send_message.
If "attributes" is set to be a list of DocumentAttribute's, these
will override the automatically inferred ones (so that you can
modify the file name of the file sent for instance).
If "is_voice_note" in kwargs, despite its value, and the file is
sent as a document, it will be sent as a voice note.
@ -512,16 +532,28 @@ class TelegramClient(TelegramBareClient):
# Determine mime-type and attributes
# Take the first element by using [0] since it returns a tuple
mime_type = guess_type(file)[0]
attributes = [
attr_dict = {
DocumentAttributeFilename:
DocumentAttributeFilename(os.path.basename(file))
# TODO If the input file is an audio, find out:
# Performer and song title and add DocumentAttributeAudio
]
}
else:
attributes = [DocumentAttributeFilename('unnamed')]
attr_dict = {
DocumentAttributeFilename:
DocumentAttributeFilename('unnamed')
}
if 'is_voice_note' in kwargs:
attributes.append(DocumentAttributeAudio(0, voice=True))
attr_dict[DocumentAttributeAudio] = \
DocumentAttributeAudio(0, voice=True)
# Now override the attributes if any. As we have a dict of
# {cls: instance}, we can override any class with the list
# of attributes provided by the user easily.
if attributes:
for a in attributes:
attr_dict[type(a)] = a
# Ensure we have a mime type, any; but it cannot be None
# 'The "octet-stream" subtype is used to indicate that a body
@ -532,7 +564,7 @@ class TelegramClient(TelegramBareClient):
media = InputMediaUploadedDocument(
file=file_handle,
mime_type=mime_type,
attributes=attributes,
attributes=list(attr_dict.values()),
caption=caption
)
@ -852,19 +884,19 @@ class TelegramClient(TelegramBareClient):
# crc32(b'InputPeer') and crc32(b'Peer')
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = await self.get_input_entity(entity)
result = None
if isinstance(ie, InputPeerUser):
result = await self(GetUsersRequest([ie]))
await self(GetUsersRequest([ie]))
elif isinstance(ie, InputPeerChat):
result = await self(GetChatsRequest([ie.chat_id]))
await self(GetChatsRequest([ie.chat_id]))
elif isinstance(ie, InputPeerChannel):
result = await self(GetChannelsRequest([ie]))
if result:
self.session.process_entities(result)
try:
return self.session.entities[ie]
except KeyError:
pass
await self(GetChannelsRequest([ie]))
try:
# session.process_entities has been called in the MtProtoSender
# with the result of these calls, so they should now be on the
# entities database.
return self.session.entities[ie]
except KeyError:
pass
if isinstance(entity, str):
return await self._get_entity_from_string(entity)
@ -880,10 +912,11 @@ class TelegramClient(TelegramBareClient):
phone = EntityDatabase.parse_phone(string)
if phone:
entity = phone
self.session.process_entities(await self(GetContactsRequest(0)))
await self(GetContactsRequest(0))
else:
entity = string.strip('@').lower()
self.session.process_entities(await self(ResolveUsernameRequest(entity)))
await self(ResolveUsernameRequest(entity))
# MtProtoSender will call .process_entities on the requests made
try:
return self.session.entities[entity]
@ -930,9 +963,17 @@ class TelegramClient(TelegramBareClient):
)
if self.session.save_entities:
# Not found, look in the dialogs (this will save the users)
await self.get_dialogs(limit=None)
# Not found, look in the latest dialogs.
# This is useful if for instance someone just sent a message but
# the updates didn't specify who, as this person or chat should
# be in the latest dialogs.
await self(GetDialogsRequest(
offset_date=None,
offset_id=0,
offset_peer=InputPeerEmpty(),
limit=0,
exclude_pinned=True
))
try:
return self.session.entities.get_input_entity(peer)
except KeyError:

View File

@ -0,0 +1 @@
from .draft import Draft

View File

@ -0,0 +1,80 @@
from ..functions.messages import SaveDraftRequest
from ..types import UpdateDraftMessage
class Draft:
"""
Custom class that encapsulates a draft on the Telegram servers, providing
an abstraction to change the message conveniently. The library will return
instances of this class when calling `client.get_drafts()`.
"""
def __init__(self, client, peer, draft):
self._client = client
self._peer = peer
self.text = draft.message
self.date = draft.date
self.no_webpage = draft.no_webpage
self.reply_to_msg_id = draft.reply_to_msg_id
self.entities = draft.entities
@classmethod
def _from_update(cls, client, update):
if not isinstance(update, UpdateDraftMessage):
raise ValueError(
'You can only create a new `Draft` from a corresponding '
'`UpdateDraftMessage` object.'
)
return cls(client=client, peer=update.peer, draft=update.draft)
@property
def entity(self):
return self._client.get_entity(self._peer)
@property
def input_entity(self):
return self._client.get_input_entity(self._peer)
def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
"""
Changes the draft message on the Telegram servers. The changes are
reflected in this object. Changing only individual attributes like for
example the `reply_to_msg_id` should be done by providing the current
values of this object, like so:
draft.set_message(
draft.text,
no_webpage=draft.no_webpage,
reply_to_msg_id=NEW_VALUE,
entities=draft.entities
)
:param str text: New text of the draft
:param bool no_webpage: Whether to attach a web page preview
:param int reply_to_msg_id: Message id to reply to
:param list entities: A list of formatting entities
:return bool: `True` on success
"""
result = self._client(SaveDraftRequest(
peer=self._peer,
message=text,
no_webpage=no_webpage,
reply_to_msg_id=reply_to_msg_id,
entities=entities
))
if result:
self.text = text
self.no_webpage = no_webpage
self.reply_to_msg_id = reply_to_msg_id
self.entities = entities
return result
def delete(self):
"""
Deletes this draft
:return bool: `True` on success
"""
return self.set_message(text='')

View File

@ -70,9 +70,7 @@ class EntityDatabase:
getattr(p, 'access_hash', 0) # chats won't have hash
if self.enabled_full:
if isinstance(e, User) \
or isinstance(e, Chat) \
or isinstance(e, Channel):
if isinstance(e, (User, Chat, Channel)):
new.append(e)
except ValueError:
pass
@ -118,47 +116,64 @@ class EntityDatabase:
if phone:
self._username_id[phone] = marked_id
def __getitem__(self, key):
"""Accepts a digit only string as phone number,
otherwise it's treated as an username.
def _parse_key(self, key):
"""Parses the given string, integer or TLObject key into a
marked user ID ready for use on self._entities.
If an integer is given, it's treated as the ID of the desired User.
The ID given won't try to be guessed as the ID of a chat or channel,
as there may be an user with that ID, and it would be unreliable.
If a callable key is given, the entity will be passed to the
function, and if it returns a true-like value, the marked ID
for such entity will be returned.
If a Peer is given (PeerUser, PeerChat, PeerChannel),
its specific entity is retrieved as User, Chat or Channel.
Note that megagroups are channels with .megagroup = True.
Raises ValueError if it cannot be parsed.
"""
if isinstance(key, str):
phone = EntityDatabase.parse_phone(key)
if phone:
return self._phone_id[phone]
else:
key = key.lstrip('@').lower()
return self._entities[self._username_id[key]]
try:
if phone:
return self._phone_id[phone]
else:
return self._username_id[key.lstrip('@').lower()]
except KeyError as e:
raise ValueError() from e
if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users
return key # normal IDs are assumed users
if isinstance(key, TLObject):
sc = type(key).SUBCLASS_OF_ID
if sc == 0x2d45687:
# Subclass of "Peer"
return self._entities[utils.get_peer_id(key, add_mark=True)]
elif sc in {0x2da17977, 0xc5af5d94, 0x6d44b7db}:
# Subclass of "User", "Chat" or "Channel"
return key
return utils.get_peer_id(key, add_mark=True)
raise KeyError(key)
if callable(key):
for k, v in self._entities.items():
if key(v):
return k
raise ValueError()
def __getitem__(self, key):
"""See the ._parse_key() docstring for possible values of the key"""
try:
return self._entities[self._parse_key(key)]
except (ValueError, KeyError) as e:
raise KeyError(key) from e
def __delitem__(self, key):
target = self[key]
del self._entities[key]
if getattr(target, 'username'):
del self._username_id[target.username]
try:
old = self._entities.pop(self._parse_key(key))
# Try removing the username and phone (if pop didn't fail),
# since the entity may have no username or phone, just ignore
# errors. It should be there if we popped the entity correctly.
try:
del self._username_id[getattr(old, 'username', None)]
except KeyError:
pass
# TODO Allow search by name by tokenizing the input and return a list
try:
del self._phone_id[getattr(old, 'phone', None)]
except KeyError:
pass
except (ValueError, KeyError) as e:
raise KeyError(key) from e
@staticmethod
def parse_phone(phone):
@ -172,8 +187,10 @@ class EntityDatabase:
def get_input_entity(self, peer):
try:
i, k = utils.get_peer_id(peer, add_mark=True, get_kind=True)
h = self._input_entities[i]
i = utils.get_peer_id(peer, add_mark=True)
h = self._input_entities[i] # we store the IDs marked
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:

View File

@ -34,5 +34,5 @@ class GzipPacked(TLObject):
@staticmethod
def read(reader):
reader.read_int(signed=False) # code
assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
return gzip.decompress(reader.tgread_bytes())

View File

@ -36,6 +36,7 @@ class Session:
self.lang_pack = session.lang_pack
self.report_errors = session.report_errors
self.save_entities = session.save_entities
self.flood_sleep_threshold = session.flood_sleep_threshold
else: # str / None
self.session_user_id = session_user_id
@ -49,6 +50,7 @@ class Session:
self.lang_pack = ''
self.report_errors = True
self.save_entities = True
self.flood_sleep_threshold = 60
self.id = helpers.generate_random_long(signed=False)
self._sequence = 0

View File

@ -1,3 +1,4 @@
from datetime import datetime
from threading import Event
@ -19,18 +20,14 @@ class TLObject:
"""
if indent is None:
if isinstance(obj, TLObject):
children = obj.to_dict(recursive=False)
if children:
return '{}: {}'.format(
type(obj).__name__, TLObject.pretty_format(children)
)
else:
return type(obj).__name__
return '{}({})'.format(type(obj).__name__, ', '.join(
'{}={}'.format(k, TLObject.pretty_format(v))
for k, v in obj.to_dict(recursive=False).items()
))
if isinstance(obj, dict):
return '{{{}}}'.format(', '.join(
'{}: {}'.format(
k, TLObject.pretty_format(v)
) for k, v in obj.items()
'{}: {}'.format(k, TLObject.pretty_format(v))
for k, v in obj.items()
))
elif isinstance(obj, str) or isinstance(obj, bytes):
return repr(obj)
@ -38,31 +35,36 @@ class TLObject:
return '[{}]'.format(
', '.join(TLObject.pretty_format(x) for x in obj)
)
elif isinstance(obj, datetime):
return 'datetime.fromtimestamp({})'.format(obj.timestamp())
else:
return str(obj)
return repr(obj)
else:
result = []
if isinstance(obj, TLObject):
result.append(type(obj).__name__)
children = obj.to_dict(recursive=False)
if children:
result.append(': ')
result.append(TLObject.pretty_format(
obj.to_dict(recursive=False), indent
))
if isinstance(obj, TLObject) or isinstance(obj, dict):
if isinstance(obj, dict):
d = obj
start, end, sep = '{', '}', ': '
else:
d = obj.to_dict(recursive=False)
start, end, sep = '(', ')', '='
result.append(type(obj).__name__)
elif isinstance(obj, dict):
result.append('{\n')
indent += 1
for k, v in obj.items():
result.append(start)
if d:
result.append('\n')
indent += 1
for k, v in d.items():
result.append('\t' * indent)
result.append(k)
result.append(sep)
result.append(TLObject.pretty_format(v, indent))
result.append(',\n')
result.pop() # last ',\n'
indent -= 1
result.append('\n')
result.append('\t' * indent)
result.append(k)
result.append(': ')
result.append(TLObject.pretty_format(v, indent))
result.append(',\n')
indent -= 1
result.append('\t' * indent)
result.append('}')
result.append(end)
elif isinstance(obj, str) or isinstance(obj, bytes):
result.append(repr(obj))
@ -78,8 +80,13 @@ class TLObject:
result.append('\t' * indent)
result.append(']')
elif isinstance(obj, datetime):
result.append('datetime.fromtimestamp(')
result.append(repr(obj.timestamp()))
result.append(')')
else:
result.append(str(obj))
result.append(repr(obj))
return ''.join(result)
@ -121,5 +128,6 @@ class TLObject:
def to_bytes(self):
return b''
def on_response(self, reader):
pass
@staticmethod
def from_reader(reader):
return TLObject()

View File

@ -1,4 +1,5 @@
import logging
import pickle
from collections import deque
from datetime import datetime
from threading import RLock, Event, Thread
@ -27,6 +28,7 @@ class UpdateState:
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
self._latest_updates = deque(maxlen=10)
self._logger = logging.getLogger(__name__)
@ -141,6 +143,26 @@ class UpdateState:
self._state.pts = pts
# TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
self._latest_updates.append(data)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
@ -149,8 +171,7 @@ class UpdateState:
self._updates.append(update.update)
self._updates_available.set()
elif isinstance(update, tl.Updates) or \
isinstance(update, tl.UpdatesCombined):
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
self._updates.extend(update.updates)
self._updates_available.set()

View File

@ -20,8 +20,8 @@ from .tl.types import (
GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty,
InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty,
FileLocationUnavailable, InputMediaUploadedDocument,
InputMediaUploadedPhoto,
DocumentAttributeFilename)
InputMediaUploadedPhoto, DocumentAttributeFilename, photos
)
def get_display_name(entity):
@ -37,7 +37,7 @@ def get_display_name(entity):
else:
return '(No name)'
if isinstance(entity, Chat) or isinstance(entity, Channel):
if isinstance(entity, (Chat, Channel)):
return entity.title
return '(unknown)'
@ -50,8 +50,7 @@ def get_extension(media):
"""Gets the corresponding extension for any Telegram media"""
# Photos are always compressed as .jpg by Telegram
if (isinstance(media, UserProfilePhoto) or isinstance(media, ChatPhoto) or
isinstance(media, MessageMediaPhoto)):
if isinstance(media, (UserProfilePhoto, ChatPhoto, MessageMediaPhoto)):
return '.jpg'
# Documents will come with a mime type
@ -87,12 +86,10 @@ def get_input_peer(entity, allow_self=True):
else:
return InputPeerUser(entity.id, entity.access_hash)
if any(isinstance(entity, c) for c in (
Chat, ChatEmpty, ChatForbidden)):
if isinstance(entity, (Chat, ChatEmpty, ChatForbidden)):
return InputPeerChat(entity.id)
if any(isinstance(entity, c) for c in (
Channel, ChannelForbidden)):
if isinstance(entity, (Channel, ChannelForbidden)):
return InputPeerChannel(entity.id, entity.access_hash)
# Less common cases
@ -122,7 +119,7 @@ def get_input_channel(entity):
if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel')
return entity
if isinstance(entity, Channel) or isinstance(entity, ChannelForbidden):
if isinstance(entity, (Channel, ChannelForbidden)):
return InputChannel(entity.id, entity.access_hash)
if isinstance(entity, InputPeerChannel):
@ -188,6 +185,9 @@ def get_input_photo(photo):
if type(photo).SUBCLASS_OF_ID == 0x846363e0: # crc32(b'InputPhoto')
return photo
if isinstance(photo, photos.Photo):
photo = photo.photo
if isinstance(photo, Photo):
return InputPhoto(id=photo.id, access_hash=photo.access_hash)
@ -263,7 +263,7 @@ def get_input_media(media, user_caption=None, is_photo=False):
if isinstance(media, MessageMediaGame):
return InputMediaGame(id=media.game.id)
if isinstance(media, ChatPhoto) or isinstance(media, UserProfilePhoto):
if isinstance(media, (ChatPhoto, UserProfilePhoto)):
if isinstance(media.photo_big, FileLocationUnavailable):
return get_input_media(media.photo_small, is_photo=True)
else:
@ -288,10 +288,9 @@ def get_input_media(media, user_caption=None, is_photo=False):
venue_id=media.venue_id
)
if any(isinstance(media, t) for t in (
if isinstance(media, (
MessageMediaEmpty, MessageMediaUnsupported,
FileLocationUnavailable, ChatPhotoEmpty,
UserProfilePhotoEmpty)):
ChatPhotoEmpty, UserProfilePhotoEmpty, FileLocationUnavailable)):
return InputMediaEmpty()
if isinstance(media, Message):
@ -300,16 +299,14 @@ def get_input_media(media, user_caption=None, is_photo=False):
_raise_cast_fail(media, 'InputMedia')
def get_peer_id(peer, add_mark=False, get_kind=False):
def get_peer_id(peer, add_mark=False):
"""Finds the ID of the given peer, and optionally converts it to
the "bot api" format if 'add_mark' is set to True.
If 'get_kind', the kind will be returned as a second value.
"""
# First we assert it's a Peer TLObject, or early return for integers
if not isinstance(peer, TLObject):
if isinstance(peer, int):
return (peer, PeerUser) if get_kind else peer
return peer
else:
_raise_cast_fail(peer, 'int')
@ -318,25 +315,20 @@ def get_peer_id(peer, add_mark=False, get_kind=False):
peer = get_input_peer(peer, allow_self=False)
# Set the right ID/kind, or raise if the TLObject is not recognised
i, k = None, None
if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser):
i, k = peer.user_id, PeerUser
elif isinstance(peer, PeerChat) or isinstance(peer, InputPeerChat):
i, k = peer.chat_id, PeerChat
elif isinstance(peer, PeerChannel) or isinstance(peer, InputPeerChannel):
i, k = peer.channel_id, PeerChannel
else:
_raise_cast_fail(peer, 'int')
if add_mark:
if k == PeerChat:
i = -i
elif k == PeerChannel:
if isinstance(peer, (PeerUser, InputPeerUser)):
return peer.user_id
elif isinstance(peer, (PeerChat, InputPeerChat)):
return -peer.chat_id if add_mark else peer.chat_id
elif isinstance(peer, (PeerChannel, InputPeerChannel)):
i = peer.channel_id
if add_mark:
# Concat -100 through math tricks, .to_supergroup() on Madeline
# IDs will be strictly positive -> log works
i = -(i + pow(10, math.floor(math.log10(i) + 3)))
return -(i + pow(10, math.floor(math.log10(i) + 3)))
else:
return i
return (i, k) if get_kind else i # return kind only if get_kind
_raise_cast_fail(peer, 'int')
def resolve_id(marked_id):
@ -375,11 +367,7 @@ def find_user_or_chat(peer, users, chats):
def get_appropriated_part_size(file_size):
"""Gets the appropriated part size when uploading or downloading files,
given an initial file size"""
if file_size <= 1048576: # 1MB
return 32
if file_size <= 10485760: # 10MB
return 64
if file_size <= 393216000: # 375MB
if file_size <= 104857600: # 100MB
return 128
if file_size <= 786432000: # 750MB
return 256

View File

@ -143,7 +143,7 @@ class TLGenerator:
builder.writeln(
'from {}.utils import get_input_peer, '
'get_input_channel, get_input_user, '
'get_input_media'.format('.' * depth)
'get_input_media, get_input_photo'.format('.' * depth)
)
# Import 'os' for those needing access to 'os.urandom()'
@ -335,32 +335,28 @@ class TLGenerator:
builder.writeln('))')
builder.end_block()
# Write the empty() function, which returns an "empty"
# instance, in which all attributes are set to None
# Write the static from_reader(reader) function
builder.writeln('@staticmethod')
builder.writeln('def empty():')
builder.writeln('def from_reader(reader):')
for arg in tlobject.args:
TLGenerator.write_read_code(
builder, arg, tlobject.args, name='_' + arg.name
)
builder.writeln('return {}({})'.format(
tlobject.class_name(), ', '.join('None' for _ in range(len(args)))
tlobject.class_name(), ', '.join(
'{0}=_{0}'.format(a.name) for a in tlobject.sorted_args()
if not a.flag_indicator and not a.generic_definition
)
))
builder.end_block()
# Write the on_response(self, reader) function
builder.writeln('def on_response(self, reader):')
# Do not read constructor's ID, since
# that's already been read somewhere else
# Only requests can have a different response that's not their
# serialized body, that is, we'll be setting their .result.
if tlobject.is_function:
builder.writeln('def on_response(self, reader):')
TLGenerator.write_request_result_code(builder, tlobject)
else:
if tlobject.args:
for arg in tlobject.args:
TLGenerator.write_onresponse_code(
builder, arg, tlobject.args
)
else:
# If there were no arguments, we still need an
# on_response method, and hence "pass" if empty
builder.writeln('pass')
builder.end_block()
builder.end_block()
# Write the __str__(self) and stringify(self) functions
builder.writeln('def __str__(self):')
@ -406,6 +402,8 @@ class TLGenerator:
TLGenerator.write_get_input(builder, arg, 'get_input_user')
elif arg.type == 'InputMedia' and tlobject.is_function:
TLGenerator.write_get_input(builder, arg, 'get_input_media')
elif arg.type == 'InputPhoto' and tlobject.is_function:
TLGenerator.write_get_input(builder, arg, 'get_input_photo')
else:
builder.writeln('self.{0} = {0}'.format(arg.name))
@ -549,9 +547,10 @@ class TLGenerator:
return True # Something was written
@staticmethod
def write_onresponse_code(builder, arg, args, name=None):
def write_read_code(builder, arg, args, name):
"""
Writes the receive code for the given argument
Writes the read code for the given argument, setting the
arg.name variable to its read value.
:param builder: The source code builder
:param arg: The argument to write
@ -565,12 +564,17 @@ class TLGenerator:
if arg.generic_definition:
return # Do nothing, this only specifies a later type
if name is None:
name = 'self.{}'.format(arg.name)
# The argument may be a flag, only write that flag was given!
was_flag = False
if arg.is_flag:
# Treat 'true' flags as a special case, since they're true if
# they're set, and nothing else needs to actually be read.
if 'true' == arg.type:
builder.writeln(
'{} = bool(flags & {})'.format(name, 1 << arg.flag_index)
)
return
was_flag = True
builder.writeln('if flags & {}:'.format(
1 << arg.flag_index
@ -585,11 +589,10 @@ class TLGenerator:
builder.writeln("reader.read_int()")
builder.writeln('{} = []'.format(name))
builder.writeln('_len = reader.read_int()')
builder.writeln('for _ in range(_len):')
builder.writeln('for _ in range(reader.read_int()):')
# Temporary disable .is_vector, not to enter this if again
arg.is_vector = False
TLGenerator.write_onresponse_code(builder, arg, args, name='_x')
TLGenerator.write_read_code(builder, arg, args, name='_x')
builder.writeln('{}.append(_x)'.format(name))
arg.is_vector = True
@ -642,7 +645,10 @@ class TLGenerator:
builder.end_block()
if was_flag:
builder.end_block()
builder.current_indent -= 1
builder.writeln('else:')
builder.writeln('{} = None'.format(name))
builder.current_indent -= 1
# Restore .is_flag
arg.is_flag = True

View File

@ -107,17 +107,17 @@ class CryptoTests(unittest.TestCase):
@staticmethod
def test_generate_key_data_from_nonce():
server_nonce = b'I am the server nonce.'
new_nonce = b'I am a new calculated nonce.'
server_nonce = int.from_bytes(b'The 16-bit nonce', byteorder='little')
new_nonce = int.from_bytes(b'The new, calculated 32-bit nonce', byteorder='little')
key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce)
expected_key = b'?\xc4\xbd\xdf\rWU\x8a\xf5\x0f+V\xdc\x96up\x1d\xeeG\x00\x81|\x1eg\x8a\x8f{\xf0y\x80\xda\xde'
expected_iv = b'Q\x9dpZ\xb7\xdd\xcb\x82_\xfa\xf4\x90\xecn\x10\x9cD\xd2\x01\x8d\x83\xa0\xa4^\xb8\x91,\x7fI am'
expected_key = b'/\xaa\x7f\xa1\xfcs\xef\xa0\x99zh\x03M\xa4\x8e\xb4\xab\x0eE]b\x95|\xfe\xc0\xf8\x1f\xd4\xa0\xd4\xec\x91'
expected_iv = b'\xf7\xae\xe3\xc8+=\xc2\xb8\xd1\xe1\x1b\x0e\x10\x07\x9fn\x9e\xdc\x960\x05\xf9\xea\xee\x8b\xa1h The '
assert key == expected_key, 'Key ("{}") does not equal expected ("{}")'.format(
key, expected_key)
assert iv == expected_iv, 'Key ("{}") does not equal expected ("{}")'.format(
key, expected_iv)
assert iv == expected_iv, 'IV ("{}") does not equal expected ("{}")'.format(
iv, expected_iv)
@staticmethod
def test_fingerprint_from_key():