Merge branch 'master' into asyncio

This commit is contained in:
Lonami Exo 2017-10-16 10:03:01 +02:00
commit 917665852d
17 changed files with 450 additions and 245 deletions

View File

@ -129,13 +129,10 @@ class BinaryReader:
return False return False
# If there was still no luck, give up # If there was still no luck, give up
self.seek(-4) # Go back
raise TypeNotFoundError(constructor_id) raise TypeNotFoundError(constructor_id)
# Create an empty instance of the class and return clazz.from_reader(self)
# fill it with the read attributes
result = clazz.empty()
result.on_response(self)
return result
def tgread_vector(self): def tgread_vector(self):
"""Reads a vector (a list) of Telegram objects""" """Reads a vector (a list) of Telegram objects"""

View File

@ -15,7 +15,7 @@ class TcpClient:
if isinstance(timeout, timedelta): if isinstance(timeout, timedelta):
self.timeout = timeout.seconds self.timeout = timeout.seconds
elif isinstance(timeout, int) or isinstance(timeout, float): elif isinstance(timeout, (int, float)):
self.timeout = float(timeout) self.timeout = float(timeout)
else: else:
raise ValueError('Invalid timeout type', type(timeout)) raise ValueError('Invalid timeout type', type(timeout))

View File

@ -124,7 +124,6 @@ async def _do_authentication(connection):
raise AssertionError(server_dh_inner) raise AssertionError(server_dh_inner)
if server_dh_inner.nonce != res_pq.nonce: if server_dh_inner.nonce != res_pq.nonce:
print(server_dh_inner.nonce, res_pq.nonce)
raise SecurityError('Invalid nonce in encrypted answer') raise SecurityError('Invalid nonce in encrypted answer')
if server_dh_inner.server_nonce != res_pq.server_nonce: if server_dh_inner.server_nonce != res_pq.server_nonce:

View File

@ -143,28 +143,25 @@ class Connection:
# TODO We don't want another call to this method that could # TODO We don't want another call to this method that could
# potentially await on another self.read(n). Is this guaranteed # potentially await on another self.read(n). Is this guaranteed
# by asyncio? # by asyncio?
packet_length_bytes = await self.read(4) packet_len_seq = await self.read(8) # 4 and 4
packet_length = int.from_bytes(packet_length_bytes, 'little') packet_len, seq = struct.unpack('<ii', packet_len_seq)
seq_bytes = await self.read(4) body = await self.read(packet_len - 12)
seq = int.from_bytes(seq_bytes, 'little') checksum = struct.unpack('<I', await self.read(4))[0]
body = await self.read(packet_length - 12) valid_checksum = crc32(packet_len_seq + body)
checksum = int.from_bytes(await self.read(4), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
if checksum != valid_checksum: if checksum != valid_checksum:
raise InvalidChecksumError(checksum, valid_checksum) raise InvalidChecksumError(checksum, valid_checksum)
return body return body
async def _recv_intermediate(self): async def _recv_intermediate(self):
return await self.read(int.from_bytes(self.read(4), 'little')) return await self.read(struct.unpack('<i', await self.read(4))[0])
async def _recv_abridged(self): async def _recv_abridged(self):
length = int.from_bytes(self.read(1), 'little') length = struct.unpack('<B', await self.read(1))[0]
if length >= 127: if length >= 127:
length = int.from_bytes(self.read(3) + b'\0', 'little') length = struct.unpack('<i', await self.read(3) + b'\0')[0]
return await self.read(length << 2) return await self.read(length << 2)

View File

@ -11,7 +11,10 @@ from ..errors import (
from ..extensions import BinaryReader from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects from ..tl.all_tlobjects import tlobjects
from ..tl.types import MsgsAck from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
)
from ..tl.functions.auth import LogOutRequest from ..tl.functions.auth import LogOutRequest
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -180,24 +183,33 @@ class MtProtoSender:
if code == 0xf35c6d01: # rpc_result, (response of an RPC call) if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
return await self._handle_rpc_result(msg_id, sequence, reader) return await self._handle_rpc_result(msg_id, sequence, reader)
if code == 0x347773c5: # pong if code == Pong.CONSTRUCTOR_ID:
return await self._handle_pong(msg_id, sequence, reader) return await self._handle_pong(msg_id, sequence, reader)
if code == 0x73f1f8dc: # msg_container if code == MessageContainer.CONSTRUCTOR_ID:
return await self._handle_container(msg_id, sequence, reader, state) return await self._handle_container(msg_id, sequence, reader, state)
if code == 0x3072cfa1: # gzip_packed if code == GzipPacked.CONSTRUCTOR_ID:
return await self._handle_gzip_packed(msg_id, sequence, reader, state) return await self._handle_gzip_packed(msg_id, sequence, reader, state)
if code == 0xedab447b: # bad_server_salt if code == BadServerSalt.CONSTRUCTOR_ID:
return await self._handle_bad_server_salt(msg_id, sequence, reader) return await self._handle_bad_server_salt(msg_id, sequence, reader)
if code == 0xa7eff811: # bad_msg_notification if code == BadMsgNotification.CONSTRUCTOR_ID:
return await self._handle_bad_msg_notification(msg_id, sequence, reader) return await self._handle_bad_msg_notification(msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted if code == MsgDetailedInfo.CONSTRUCTOR_ID:
if code == 0x62d6b459: return await self._handle_msg_detailed_info(msg_id, sequence, reader)
if code == MsgNewDetailedInfo.CONSTRUCTOR_ID:
return await self._handle_msg_new_detailed_info(msg_id, sequence, reader)
if code == NewSessionCreated.CONSTRUCTOR_ID:
return await self._handle_new_session_created(msg_id, sequence, reader)
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object() ack = reader.tgread_object()
assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's # Ignore every ack request *unless* when logging out, when it's
# when it seems to only make sense. We also need to set a non-None # when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these. # result since Telegram doesn't send the response for these.
@ -219,7 +231,12 @@ class MtProtoSender:
return True return True
self._logger.debug('Unknown message: {}'.format(hex(code))) self._logger.debug(
'[WARN] Unknown message: {}, data left in the buffer: {}'
.format(
hex(code), repr(reader.get_bytes()[reader.tell_position():])
)
)
return False return False
# endregion # endregion
@ -239,7 +256,7 @@ class MtProtoSender:
the given type, or returns None if it's not found/doesn't match. the given type, or returns None if it's not found/doesn't match.
""" """
message = self._pending_receive.get(msg_id, None) message = self._pending_receive.get(msg_id, None)
if isinstance(message.request, t): if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request return self._pending_receive.pop(msg_id).request
def _clear_all_pending(self): def _clear_all_pending(self):
@ -249,12 +266,13 @@ class MtProtoSender:
async def _handle_pong(self, msg_id, sequence, reader): async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
reader.read_int(signed=False) # code pong = reader.tgread_object()
received_msg_id = reader.read_long() assert isinstance(pong, Pong)
request = self._pop_request(received_msg_id) request = self._pop_request(pong.msg_id)
if request: if request:
self._logger.debug('Pong confirmed a request') self._logger.debug('Pong confirmed a request')
request.result = pong
request.confirm_received.set() request.confirm_received.set()
return True return True
@ -278,14 +296,15 @@ class MtProtoSender:
async def _handle_bad_server_salt(self, msg_id, sequence, reader): async def _handle_bad_server_salt(self, msg_id, sequence, reader):
self._logger.debug('Handling bad server salt') self._logger.debug('Handling bad server salt')
reader.read_int(signed=False) # code bad_salt = reader.tgread_object()
bad_msg_id = reader.read_long() assert isinstance(bad_salt, BadServerSalt)
reader.read_int() # bad_msg_seq_no
reader.read_int() # error_code
new_salt = reader.read_long(signed=False)
self.session.salt = new_salt
request = self._pop_request(bad_msg_id) # Our salt is unsigned, but the objects work with signed salts
self.session.salt = struct.unpack(
'<Q', struct.pack('<q', bad_salt.new_server_salt)
)[0]
request = self._pop_request(bad_salt.bad_msg_id)
if request: if request:
await self.send(request) await self.send(request)
@ -293,31 +312,53 @@ class MtProtoSender:
async def _handle_bad_msg_notification(self, msg_id, sequence, reader): async def _handle_bad_msg_notification(self, msg_id, sequence, reader):
self._logger.debug('Handling bad message notification') self._logger.debug('Handling bad message notification')
reader.read_int(signed=False) # code bad_msg = reader.tgread_object()
reader.read_long() # request_id assert isinstance(bad_msg, BadMsgNotification)
reader.read_int() # request_sequence
error_code = reader.read_int() error = BadMessageError(bad_msg.error_code)
error = BadMessageError(error_code) if bad_msg.error_code in (16, 17):
if error_code in (16, 17):
# sent msg_id too low or too high (respectively). # sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset. # Use the current msg_id to determine the right time offset.
self.session.update_time_offset(correct_msg_id=msg_id) self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error)) self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.') self._logger.debug('Attempting to use the correct time offset.')
return True return True
elif error_code == 32: elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount # msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID # TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64 self.session._sequence += 64
return True return True
elif error_code == 33: elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case # msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16 self.session._sequence -= 16
return True return True
else: else:
raise error raise error
async def _handle_msg_detailed_info(self, msg_id, sequence, reader):
msg_new = reader.tgread_object()
assert isinstance(msg_new, MsgDetailedInfo)
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
msg_new = reader.tgread_object()
assert isinstance(msg_new, MsgNewDetailedInfo)
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_new_session_created(self, msg_id, sequence, reader):
new_session = reader.tgread_object()
assert isinstance(new_session, NewSessionCreated)
# TODO https://goo.gl/LMyN7A
return True
async def _handle_rpc_result(self, msg_id, sequence, reader): async def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result') self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
@ -346,8 +387,9 @@ class MtProtoSender:
# else TODO Where should this error be reported? # else TODO Where should this error be reported?
# Read may be async. Can an error not-belong to a request? # Read may be async. Can an error not-belong to a request?
self._logger.debug('Read RPC error: %s', str(error)) self._logger.debug('Read RPC error: %s', str(error))
else: return True # All contents were read okay
if request:
elif request:
self._logger.debug('Reading request response') self._logger.debug('Reading request response')
if inner_code == 0x3072cfa1: # GZip packed if inner_code == 0x3072cfa1: # GZip packed
unpacked_data = gzip.decompress(reader.tgread_bytes()) unpacked_data = gzip.decompress(reader.tgread_bytes())
@ -360,7 +402,7 @@ class MtProtoSender:
self.session.process_entities(request.result) self.session.process_entities(request.result)
request.confirm_received.set() request.confirm_received.set()
return True return True
else:
# If it's really a result for RPC from previous connection # If it's really a result for RPC from previous connection
# session, it will be skipped by the handle_container() # session, it will be skipped by the handle_container()
self._logger.debug('Lost request will be skipped.') self._logger.debug('Lost request will be skipped.')

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
import warnings
from datetime import timedelta, datetime from datetime import timedelta, datetime
from hashlib import md5 from hashlib import md5
from io import BytesIO from io import BytesIO
@ -55,7 +56,7 @@ class TelegramBareClient:
""" """
# Current TelegramClient version # Current TelegramClient version
__version__ = '0.15.1' __version__ = '0.15.2'
# TODO Make this thread-safe, all connections share the same DC # TODO Make this thread-safe, all connections share the same DC
_dc_options = None _dc_options = None
@ -386,7 +387,7 @@ class TelegramBareClient:
try: try:
for _ in range(retries): for _ in range(retries):
result = await self._invoke(sender, *requests) result = await self._invoke(sender, *requests)
if result: if result is not None:
return result return result
raise ValueError('Number of retries reached 0.') raise ValueError('Number of retries reached 0.')
@ -412,7 +413,7 @@ class TelegramBareClient:
pass # We will just retry pass # We will just retry
except ConnectionResetError: except ConnectionResetError:
if not self._authorized: if not self._user_connected:
# Only attempt reconnecting if we're authorized # Only attempt reconnecting if we're authorized
raise raise
@ -459,11 +460,15 @@ class TelegramBareClient:
'[ERROR] Telegram is having some internal issues', e '[ERROR] Telegram is having some internal issues', e
) )
except FloodWaitError: except FloodWaitError as e:
sender.disconnect() if e.seconds > self.session.flood_sleep_threshold | 0:
self.disconnect()
raise raise
self._logger.debug(
'Sleep of %d seconds below threshold, sleeping' % e.seconds
)
sleep(e.seconds)
# Some really basic functionality # Some really basic functionality
def is_user_authorized(self): def is_user_authorized(self):
@ -609,10 +614,8 @@ class TelegramBareClient:
cdn_decrypter = None cdn_decrypter = None
try: try:
offset_index = 0 offset = 0
while True: while True:
offset = offset_index * part_size
try: try:
if cdn_decrypter: if cdn_decrypter:
result = await cdn_decrypter.get_file() result = await cdn_decrypter.get_file()
@ -633,7 +636,7 @@ class TelegramBareClient:
client = await self._get_exported_client(e.new_dc) client = await self._get_exported_client(e.new_dc)
continue continue
offset_index += 1 offset += part_size
# If we have received no data (0 bytes), the file is over # If we have received no data (0 bytes), the file is over
# So there is nothing left to download and write # So there is nothing left to download and write
@ -670,6 +673,9 @@ class TelegramBareClient:
def add_update_handler(self, handler): def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates""" an update, as its parameter) and listens for updates"""
if not self.updates.get_workers:
warnings.warn("There are no update workers running, so adding an update handler will have no effect.")
sync = not self.updates.handlers sync = not self.updates.handlers
self.updates.handlers.append(handler) self.updates.handlers.append(handler)
if sync: if sync:

View File

@ -15,6 +15,7 @@ from .errors import (
) )
from .network import ConnectionMode from .network import ConnectionMode
from .tl import TLObject from .tl import TLObject
from .tl.custom import Draft
from .tl.entity_database import EntityDatabase from .tl.entity_database import EntityDatabase
from .tl.functions.account import ( from .tl.functions.account import (
GetPasswordRequest GetPasswordRequest
@ -28,8 +29,8 @@ from .tl.functions.contacts import (
) )
from .tl.functions.messages import ( from .tl.functions.messages import (
GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest, GetDialogsRequest, GetHistoryRequest, ReadHistoryRequest, SendMediaRequest,
SendMessageRequest, GetChatsRequest SendMessageRequest, GetChatsRequest,
) GetAllDraftsRequest)
from .tl.functions import channels from .tl.functions import channels
from .tl.functions import messages from .tl.functions import messages
@ -46,7 +47,7 @@ from .tl.types import (
InputMediaUploadedDocument, InputMediaUploadedPhoto, InputPeerEmpty, InputMediaUploadedDocument, InputMediaUploadedPhoto, InputPeerEmpty,
Message, MessageMediaContact, MessageMediaDocument, MessageMediaPhoto, Message, MessageMediaContact, MessageMediaDocument, MessageMediaPhoto,
InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID, InputUserSelf, UserProfilePhoto, ChatPhoto, UpdateMessageID,
UpdateNewMessage, UpdateShortSentMessage, UpdateNewChannelMessage, UpdateNewMessage, UpdateShortSentMessage,
PeerUser, InputPeerUser, InputPeerChat, InputPeerChannel) PeerUser, InputPeerUser, InputPeerChat, InputPeerChannel)
from .tl.types.messages import DialogsSlice from .tl.types.messages import DialogsSlice
@ -227,7 +228,7 @@ class TelegramClient(TelegramBareClient):
if limit is None: if limit is None:
limit = float('inf') limit = float('inf')
dialogs = {} # Use Dialog.top_message as identifier to avoid dupes dialogs = {} # Use peer id as identifier to avoid dupes
messages = {} # Used later for sorting TODO also return these? messages = {} # Used later for sorting TODO also return these?
entities = {} entities = {}
while len(dialogs) < limit: while len(dialogs) < limit:
@ -242,7 +243,7 @@ class TelegramClient(TelegramBareClient):
break break
for d in r.dialogs: for d in r.dialogs:
dialogs[d.top_message] = d dialogs[utils.get_peer_id(d.peer, True)] = d
for m in r.messages: for m in r.messages:
messages[m.id] = m messages[m.id] = m
@ -277,9 +278,20 @@ class TelegramClient(TelegramBareClient):
[utils.find_user_or_chat(d.peer, entities, entities) for d in ds] [utils.find_user_or_chat(d.peer, entities, entities) for d in ds]
) )
# endregion async def get_drafts(self): # TODO: Ability to provide a `filter`
"""
Gets all open draft messages.
# region Message requests Returns a list of custom `Draft` objects that are easy to work with: You can call
`draft.set_message('text')` to change the message, or delete it through `draft.delete()`.
:return List[telethon.tl.custom.Draft]: A list of open drafts
"""
response = await self(GetAllDraftsRequest())
self.session.process_entities(response)
self.session.generate_sequence(response.seq)
drafts = [Draft._from_update(self, u) for u in response.updates]
return drafts
async def send_message(self, async def send_message(self,
entity, entity,
@ -322,7 +334,7 @@ class TelegramClient(TelegramBareClient):
break break
for update in result.updates: for update in result.updates:
if isinstance(update, UpdateNewMessage): if isinstance(update, (UpdateNewChannelMessage, UpdateNewMessage)):
if update.message.id == msg_id: if update.message.id == msg_id:
return update.message return update.message
@ -463,9 +475,13 @@ class TelegramClient(TelegramBareClient):
async def send_file(self, entity, file, caption='', async def send_file(self, entity, file, caption='',
force_document=False, progress_callback=None, force_document=False, progress_callback=None,
reply_to=None, reply_to=None,
attributes=None,
**kwargs): **kwargs):
"""Sends a file to the specified entity. """Sends a file to the specified entity.
The file may either be a path, a byte array, or a stream. The file may either be a path, a byte array, or a stream.
Note that if a byte array or a stream is given, a filename
or its type won't be inferred, and it will be sent as an
"unnamed application/octet-stream".
An optional caption can also be specified for said file. An optional caption can also be specified for said file.
@ -482,6 +498,10 @@ class TelegramClient(TelegramBareClient):
The "reply_to" parameter works exactly as the one on .send_message. The "reply_to" parameter works exactly as the one on .send_message.
If "attributes" is set to be a list of DocumentAttribute's, these
will override the automatically inferred ones (so that you can
modify the file name of the file sent for instance).
If "is_voice_note" in kwargs, despite its value, and the file is If "is_voice_note" in kwargs, despite its value, and the file is
sent as a document, it will be sent as a voice note. sent as a document, it will be sent as a voice note.
@ -512,16 +532,28 @@ class TelegramClient(TelegramBareClient):
# Determine mime-type and attributes # Determine mime-type and attributes
# Take the first element by using [0] since it returns a tuple # Take the first element by using [0] since it returns a tuple
mime_type = guess_type(file)[0] mime_type = guess_type(file)[0]
attributes = [ attr_dict = {
DocumentAttributeFilename:
DocumentAttributeFilename(os.path.basename(file)) DocumentAttributeFilename(os.path.basename(file))
# TODO If the input file is an audio, find out: # TODO If the input file is an audio, find out:
# Performer and song title and add DocumentAttributeAudio # Performer and song title and add DocumentAttributeAudio
] }
else: else:
attributes = [DocumentAttributeFilename('unnamed')] attr_dict = {
DocumentAttributeFilename:
DocumentAttributeFilename('unnamed')
}
if 'is_voice_note' in kwargs: if 'is_voice_note' in kwargs:
attributes.append(DocumentAttributeAudio(0, voice=True)) attr_dict[DocumentAttributeAudio] = \
DocumentAttributeAudio(0, voice=True)
# Now override the attributes if any. As we have a dict of
# {cls: instance}, we can override any class with the list
# of attributes provided by the user easily.
if attributes:
for a in attributes:
attr_dict[type(a)] = a
# Ensure we have a mime type, any; but it cannot be None # Ensure we have a mime type, any; but it cannot be None
# 'The "octet-stream" subtype is used to indicate that a body # 'The "octet-stream" subtype is used to indicate that a body
@ -532,7 +564,7 @@ class TelegramClient(TelegramBareClient):
media = InputMediaUploadedDocument( media = InputMediaUploadedDocument(
file=file_handle, file=file_handle,
mime_type=mime_type, mime_type=mime_type,
attributes=attributes, attributes=list(attr_dict.values()),
caption=caption caption=caption
) )
@ -852,16 +884,16 @@ class TelegramClient(TelegramBareClient):
# crc32(b'InputPeer') and crc32(b'Peer') # crc32(b'InputPeer') and crc32(b'Peer')
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)): type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = await self.get_input_entity(entity) ie = await self.get_input_entity(entity)
result = None
if isinstance(ie, InputPeerUser): if isinstance(ie, InputPeerUser):
result = await self(GetUsersRequest([ie])) await self(GetUsersRequest([ie]))
elif isinstance(ie, InputPeerChat): elif isinstance(ie, InputPeerChat):
result = await self(GetChatsRequest([ie.chat_id])) await self(GetChatsRequest([ie.chat_id]))
elif isinstance(ie, InputPeerChannel): elif isinstance(ie, InputPeerChannel):
result = await self(GetChannelsRequest([ie])) await self(GetChannelsRequest([ie]))
if result:
self.session.process_entities(result)
try: try:
# session.process_entities has been called in the MtProtoSender
# with the result of these calls, so they should now be on the
# entities database.
return self.session.entities[ie] return self.session.entities[ie]
except KeyError: except KeyError:
pass pass
@ -880,10 +912,11 @@ class TelegramClient(TelegramBareClient):
phone = EntityDatabase.parse_phone(string) phone = EntityDatabase.parse_phone(string)
if phone: if phone:
entity = phone entity = phone
self.session.process_entities(await self(GetContactsRequest(0))) await self(GetContactsRequest(0))
else: else:
entity = string.strip('@').lower() entity = string.strip('@').lower()
self.session.process_entities(await self(ResolveUsernameRequest(entity))) await self(ResolveUsernameRequest(entity))
# MtProtoSender will call .process_entities on the requests made
try: try:
return self.session.entities[entity] return self.session.entities[entity]
@ -930,9 +963,17 @@ class TelegramClient(TelegramBareClient):
) )
if self.session.save_entities: if self.session.save_entities:
# Not found, look in the dialogs (this will save the users) # Not found, look in the latest dialogs.
await self.get_dialogs(limit=None) # This is useful if for instance someone just sent a message but
# the updates didn't specify who, as this person or chat should
# be in the latest dialogs.
await self(GetDialogsRequest(
offset_date=None,
offset_id=0,
offset_peer=InputPeerEmpty(),
limit=0,
exclude_pinned=True
))
try: try:
return self.session.entities.get_input_entity(peer) return self.session.entities.get_input_entity(peer)
except KeyError: except KeyError:

View File

@ -0,0 +1 @@
from .draft import Draft

View File

@ -0,0 +1,80 @@
from ..functions.messages import SaveDraftRequest
from ..types import UpdateDraftMessage
class Draft:
"""
Custom class that encapsulates a draft on the Telegram servers, providing
an abstraction to change the message conveniently. The library will return
instances of this class when calling `client.get_drafts()`.
"""
def __init__(self, client, peer, draft):
self._client = client
self._peer = peer
self.text = draft.message
self.date = draft.date
self.no_webpage = draft.no_webpage
self.reply_to_msg_id = draft.reply_to_msg_id
self.entities = draft.entities
@classmethod
def _from_update(cls, client, update):
if not isinstance(update, UpdateDraftMessage):
raise ValueError(
'You can only create a new `Draft` from a corresponding '
'`UpdateDraftMessage` object.'
)
return cls(client=client, peer=update.peer, draft=update.draft)
@property
def entity(self):
return self._client.get_entity(self._peer)
@property
def input_entity(self):
return self._client.get_input_entity(self._peer)
def set_message(self, text, no_webpage=None, reply_to_msg_id=None, entities=None):
"""
Changes the draft message on the Telegram servers. The changes are
reflected in this object. Changing only individual attributes like for
example the `reply_to_msg_id` should be done by providing the current
values of this object, like so:
draft.set_message(
draft.text,
no_webpage=draft.no_webpage,
reply_to_msg_id=NEW_VALUE,
entities=draft.entities
)
:param str text: New text of the draft
:param bool no_webpage: Whether to attach a web page preview
:param int reply_to_msg_id: Message id to reply to
:param list entities: A list of formatting entities
:return bool: `True` on success
"""
result = self._client(SaveDraftRequest(
peer=self._peer,
message=text,
no_webpage=no_webpage,
reply_to_msg_id=reply_to_msg_id,
entities=entities
))
if result:
self.text = text
self.no_webpage = no_webpage
self.reply_to_msg_id = reply_to_msg_id
self.entities = entities
return result
def delete(self):
"""
Deletes this draft
:return bool: `True` on success
"""
return self.set_message(text='')

View File

@ -70,9 +70,7 @@ class EntityDatabase:
getattr(p, 'access_hash', 0) # chats won't have hash getattr(p, 'access_hash', 0) # chats won't have hash
if self.enabled_full: if self.enabled_full:
if isinstance(e, User) \ if isinstance(e, (User, Chat, Channel)):
or isinstance(e, Chat) \
or isinstance(e, Channel):
new.append(e) new.append(e)
except ValueError: except ValueError:
pass pass
@ -118,47 +116,64 @@ class EntityDatabase:
if phone: if phone:
self._username_id[phone] = marked_id self._username_id[phone] = marked_id
def __getitem__(self, key): def _parse_key(self, key):
"""Accepts a digit only string as phone number, """Parses the given string, integer or TLObject key into a
otherwise it's treated as an username. marked user ID ready for use on self._entities.
If an integer is given, it's treated as the ID of the desired User. If a callable key is given, the entity will be passed to the
The ID given won't try to be guessed as the ID of a chat or channel, function, and if it returns a true-like value, the marked ID
as there may be an user with that ID, and it would be unreliable. for such entity will be returned.
If a Peer is given (PeerUser, PeerChat, PeerChannel), Raises ValueError if it cannot be parsed.
its specific entity is retrieved as User, Chat or Channel.
Note that megagroups are channels with .megagroup = True.
""" """
if isinstance(key, str): if isinstance(key, str):
phone = EntityDatabase.parse_phone(key) phone = EntityDatabase.parse_phone(key)
try:
if phone: if phone:
return self._phone_id[phone] return self._phone_id[phone]
else: else:
key = key.lstrip('@').lower() return self._username_id[key.lstrip('@').lower()]
return self._entities[self._username_id[key]] except KeyError as e:
raise ValueError() from e
if isinstance(key, int): if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users return key # normal IDs are assumed users
if isinstance(key, TLObject): if isinstance(key, TLObject):
sc = type(key).SUBCLASS_OF_ID return utils.get_peer_id(key, add_mark=True)
if sc == 0x2d45687:
# Subclass of "Peer"
return self._entities[utils.get_peer_id(key, add_mark=True)]
elif sc in {0x2da17977, 0xc5af5d94, 0x6d44b7db}:
# Subclass of "User", "Chat" or "Channel"
return key
raise KeyError(key) if callable(key):
for k, v in self._entities.items():
if key(v):
return k
raise ValueError()
def __getitem__(self, key):
"""See the ._parse_key() docstring for possible values of the key"""
try:
return self._entities[self._parse_key(key)]
except (ValueError, KeyError) as e:
raise KeyError(key) from e
def __delitem__(self, key): def __delitem__(self, key):
target = self[key] try:
del self._entities[key] old = self._entities.pop(self._parse_key(key))
if getattr(target, 'username'): # Try removing the username and phone (if pop didn't fail),
del self._username_id[target.username] # since the entity may have no username or phone, just ignore
# errors. It should be there if we popped the entity correctly.
try:
del self._username_id[getattr(old, 'username', None)]
except KeyError:
pass
# TODO Allow search by name by tokenizing the input and return a list try:
del self._phone_id[getattr(old, 'phone', None)]
except KeyError:
pass
except (ValueError, KeyError) as e:
raise KeyError(key) from e
@staticmethod @staticmethod
def parse_phone(phone): def parse_phone(phone):
@ -172,8 +187,10 @@ class EntityDatabase:
def get_input_entity(self, peer): def get_input_entity(self, peer):
try: try:
i, k = utils.get_peer_id(peer, add_mark=True, get_kind=True) i = utils.get_peer_id(peer, add_mark=True)
h = self._input_entities[i] h = self._input_entities[i] # we store the IDs marked
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser: if k == PeerUser:
return InputPeerUser(i, h) return InputPeerUser(i, h)
elif k == PeerChat: elif k == PeerChat:

View File

@ -34,5 +34,5 @@ class GzipPacked(TLObject):
@staticmethod @staticmethod
def read(reader): def read(reader):
reader.read_int(signed=False) # code assert reader.read_int(signed=False) == GzipPacked.CONSTRUCTOR_ID
return gzip.decompress(reader.tgread_bytes()) return gzip.decompress(reader.tgread_bytes())

View File

@ -36,6 +36,7 @@ class Session:
self.lang_pack = session.lang_pack self.lang_pack = session.lang_pack
self.report_errors = session.report_errors self.report_errors = session.report_errors
self.save_entities = session.save_entities self.save_entities = session.save_entities
self.flood_sleep_threshold = session.flood_sleep_threshold
else: # str / None else: # str / None
self.session_user_id = session_user_id self.session_user_id = session_user_id
@ -49,6 +50,7 @@ class Session:
self.lang_pack = '' self.lang_pack = ''
self.report_errors = True self.report_errors = True
self.save_entities = True self.save_entities = True
self.flood_sleep_threshold = 60
self.id = helpers.generate_random_long(signed=False) self.id = helpers.generate_random_long(signed=False)
self._sequence = 0 self._sequence = 0

View File

@ -1,3 +1,4 @@
from datetime import datetime
from threading import Event from threading import Event
@ -19,18 +20,14 @@ class TLObject:
""" """
if indent is None: if indent is None:
if isinstance(obj, TLObject): if isinstance(obj, TLObject):
children = obj.to_dict(recursive=False) return '{}({})'.format(type(obj).__name__, ', '.join(
if children: '{}={}'.format(k, TLObject.pretty_format(v))
return '{}: {}'.format( for k, v in obj.to_dict(recursive=False).items()
type(obj).__name__, TLObject.pretty_format(children) ))
)
else:
return type(obj).__name__
if isinstance(obj, dict): if isinstance(obj, dict):
return '{{{}}}'.format(', '.join( return '{{{}}}'.format(', '.join(
'{}: {}'.format( '{}: {}'.format(k, TLObject.pretty_format(v))
k, TLObject.pretty_format(v) for k, v in obj.items()
) for k, v in obj.items()
)) ))
elif isinstance(obj, str) or isinstance(obj, bytes): elif isinstance(obj, str) or isinstance(obj, bytes):
return repr(obj) return repr(obj)
@ -38,31 +35,36 @@ class TLObject:
return '[{}]'.format( return '[{}]'.format(
', '.join(TLObject.pretty_format(x) for x in obj) ', '.join(TLObject.pretty_format(x) for x in obj)
) )
elif isinstance(obj, datetime):
return 'datetime.fromtimestamp({})'.format(obj.timestamp())
else: else:
return str(obj) return repr(obj)
else: else:
result = [] result = []
if isinstance(obj, TLObject): if isinstance(obj, TLObject) or isinstance(obj, dict):
if isinstance(obj, dict):
d = obj
start, end, sep = '{', '}', ': '
else:
d = obj.to_dict(recursive=False)
start, end, sep = '(', ')', '='
result.append(type(obj).__name__) result.append(type(obj).__name__)
children = obj.to_dict(recursive=False)
if children:
result.append(': ')
result.append(TLObject.pretty_format(
obj.to_dict(recursive=False), indent
))
elif isinstance(obj, dict): result.append(start)
result.append('{\n') if d:
result.append('\n')
indent += 1 indent += 1
for k, v in obj.items(): for k, v in d.items():
result.append('\t' * indent) result.append('\t' * indent)
result.append(k) result.append(k)
result.append(': ') result.append(sep)
result.append(TLObject.pretty_format(v, indent)) result.append(TLObject.pretty_format(v, indent))
result.append(',\n') result.append(',\n')
result.pop() # last ',\n'
indent -= 1 indent -= 1
result.append('\n')
result.append('\t' * indent) result.append('\t' * indent)
result.append('}') result.append(end)
elif isinstance(obj, str) or isinstance(obj, bytes): elif isinstance(obj, str) or isinstance(obj, bytes):
result.append(repr(obj)) result.append(repr(obj))
@ -78,8 +80,13 @@ class TLObject:
result.append('\t' * indent) result.append('\t' * indent)
result.append(']') result.append(']')
elif isinstance(obj, datetime):
result.append('datetime.fromtimestamp(')
result.append(repr(obj.timestamp()))
result.append(')')
else: else:
result.append(str(obj)) result.append(repr(obj))
return ''.join(result) return ''.join(result)
@ -121,5 +128,6 @@ class TLObject:
def to_bytes(self): def to_bytes(self):
return b'' return b''
def on_response(self, reader): @staticmethod
pass def from_reader(reader):
return TLObject()

View File

@ -1,4 +1,5 @@
import logging import logging
import pickle
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from threading import RLock, Event, Thread from threading import RLock, Event, Thread
@ -27,6 +28,7 @@ class UpdateState:
self._updates_lock = RLock() self._updates_lock = RLock()
self._updates_available = Event() self._updates_available = Event()
self._updates = deque() self._updates = deque()
self._latest_updates = deque(maxlen=10)
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
@ -141,6 +143,26 @@ class UpdateState:
self._state.pts = pts self._state.pts = pts
# TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
self._latest_updates.append(data)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates') if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks. # Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we # Since .users and .chats have already been processed, we
@ -149,8 +171,7 @@ class UpdateState:
self._updates.append(update.update) self._updates.append(update.update)
self._updates_available.set() self._updates_available.set()
elif isinstance(update, tl.Updates) or \ elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
isinstance(update, tl.UpdatesCombined):
self._updates.extend(update.updates) self._updates.extend(update.updates)
self._updates_available.set() self._updates_available.set()

View File

@ -20,8 +20,8 @@ from .tl.types import (
GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty, GeoPointEmpty, InputGeoPointEmpty, Photo, InputPhoto, PhotoEmpty,
InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty, InputPhotoEmpty, FileLocation, ChatPhotoEmpty, UserProfilePhotoEmpty,
FileLocationUnavailable, InputMediaUploadedDocument, FileLocationUnavailable, InputMediaUploadedDocument,
InputMediaUploadedPhoto, InputMediaUploadedPhoto, DocumentAttributeFilename, photos
DocumentAttributeFilename) )
def get_display_name(entity): def get_display_name(entity):
@ -37,7 +37,7 @@ def get_display_name(entity):
else: else:
return '(No name)' return '(No name)'
if isinstance(entity, Chat) or isinstance(entity, Channel): if isinstance(entity, (Chat, Channel)):
return entity.title return entity.title
return '(unknown)' return '(unknown)'
@ -50,8 +50,7 @@ def get_extension(media):
"""Gets the corresponding extension for any Telegram media""" """Gets the corresponding extension for any Telegram media"""
# Photos are always compressed as .jpg by Telegram # Photos are always compressed as .jpg by Telegram
if (isinstance(media, UserProfilePhoto) or isinstance(media, ChatPhoto) or if isinstance(media, (UserProfilePhoto, ChatPhoto, MessageMediaPhoto)):
isinstance(media, MessageMediaPhoto)):
return '.jpg' return '.jpg'
# Documents will come with a mime type # Documents will come with a mime type
@ -87,12 +86,10 @@ def get_input_peer(entity, allow_self=True):
else: else:
return InputPeerUser(entity.id, entity.access_hash) return InputPeerUser(entity.id, entity.access_hash)
if any(isinstance(entity, c) for c in ( if isinstance(entity, (Chat, ChatEmpty, ChatForbidden)):
Chat, ChatEmpty, ChatForbidden)):
return InputPeerChat(entity.id) return InputPeerChat(entity.id)
if any(isinstance(entity, c) for c in ( if isinstance(entity, (Channel, ChannelForbidden)):
Channel, ChannelForbidden)):
return InputPeerChannel(entity.id, entity.access_hash) return InputPeerChannel(entity.id, entity.access_hash)
# Less common cases # Less common cases
@ -122,7 +119,7 @@ def get_input_channel(entity):
if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel') if type(entity).SUBCLASS_OF_ID == 0x40f202fd: # crc32(b'InputChannel')
return entity return entity
if isinstance(entity, Channel) or isinstance(entity, ChannelForbidden): if isinstance(entity, (Channel, ChannelForbidden)):
return InputChannel(entity.id, entity.access_hash) return InputChannel(entity.id, entity.access_hash)
if isinstance(entity, InputPeerChannel): if isinstance(entity, InputPeerChannel):
@ -188,6 +185,9 @@ def get_input_photo(photo):
if type(photo).SUBCLASS_OF_ID == 0x846363e0: # crc32(b'InputPhoto') if type(photo).SUBCLASS_OF_ID == 0x846363e0: # crc32(b'InputPhoto')
return photo return photo
if isinstance(photo, photos.Photo):
photo = photo.photo
if isinstance(photo, Photo): if isinstance(photo, Photo):
return InputPhoto(id=photo.id, access_hash=photo.access_hash) return InputPhoto(id=photo.id, access_hash=photo.access_hash)
@ -263,7 +263,7 @@ def get_input_media(media, user_caption=None, is_photo=False):
if isinstance(media, MessageMediaGame): if isinstance(media, MessageMediaGame):
return InputMediaGame(id=media.game.id) return InputMediaGame(id=media.game.id)
if isinstance(media, ChatPhoto) or isinstance(media, UserProfilePhoto): if isinstance(media, (ChatPhoto, UserProfilePhoto)):
if isinstance(media.photo_big, FileLocationUnavailable): if isinstance(media.photo_big, FileLocationUnavailable):
return get_input_media(media.photo_small, is_photo=True) return get_input_media(media.photo_small, is_photo=True)
else: else:
@ -288,10 +288,9 @@ def get_input_media(media, user_caption=None, is_photo=False):
venue_id=media.venue_id venue_id=media.venue_id
) )
if any(isinstance(media, t) for t in ( if isinstance(media, (
MessageMediaEmpty, MessageMediaUnsupported, MessageMediaEmpty, MessageMediaUnsupported,
FileLocationUnavailable, ChatPhotoEmpty, ChatPhotoEmpty, UserProfilePhotoEmpty, FileLocationUnavailable)):
UserProfilePhotoEmpty)):
return InputMediaEmpty() return InputMediaEmpty()
if isinstance(media, Message): if isinstance(media, Message):
@ -300,16 +299,14 @@ def get_input_media(media, user_caption=None, is_photo=False):
_raise_cast_fail(media, 'InputMedia') _raise_cast_fail(media, 'InputMedia')
def get_peer_id(peer, add_mark=False, get_kind=False): def get_peer_id(peer, add_mark=False):
"""Finds the ID of the given peer, and optionally converts it to """Finds the ID of the given peer, and optionally converts it to
the "bot api" format if 'add_mark' is set to True. the "bot api" format if 'add_mark' is set to True.
If 'get_kind', the kind will be returned as a second value.
""" """
# First we assert it's a Peer TLObject, or early return for integers # First we assert it's a Peer TLObject, or early return for integers
if not isinstance(peer, TLObject): if not isinstance(peer, TLObject):
if isinstance(peer, int): if isinstance(peer, int):
return (peer, PeerUser) if get_kind else peer return peer
else: else:
_raise_cast_fail(peer, 'int') _raise_cast_fail(peer, 'int')
@ -318,25 +315,20 @@ def get_peer_id(peer, add_mark=False, get_kind=False):
peer = get_input_peer(peer, allow_self=False) peer = get_input_peer(peer, allow_self=False)
# Set the right ID/kind, or raise if the TLObject is not recognised # Set the right ID/kind, or raise if the TLObject is not recognised
i, k = None, None if isinstance(peer, (PeerUser, InputPeerUser)):
if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser): return peer.user_id
i, k = peer.user_id, PeerUser elif isinstance(peer, (PeerChat, InputPeerChat)):
elif isinstance(peer, PeerChat) or isinstance(peer, InputPeerChat): return -peer.chat_id if add_mark else peer.chat_id
i, k = peer.chat_id, PeerChat elif isinstance(peer, (PeerChannel, InputPeerChannel)):
elif isinstance(peer, PeerChannel) or isinstance(peer, InputPeerChannel): i = peer.channel_id
i, k = peer.channel_id, PeerChannel
else:
_raise_cast_fail(peer, 'int')
if add_mark: if add_mark:
if k == PeerChat:
i = -i
elif k == PeerChannel:
# Concat -100 through math tricks, .to_supergroup() on Madeline # Concat -100 through math tricks, .to_supergroup() on Madeline
# IDs will be strictly positive -> log works # IDs will be strictly positive -> log works
i = -(i + pow(10, math.floor(math.log10(i) + 3))) return -(i + pow(10, math.floor(math.log10(i) + 3)))
else:
return i
return (i, k) if get_kind else i # return kind only if get_kind _raise_cast_fail(peer, 'int')
def resolve_id(marked_id): def resolve_id(marked_id):
@ -375,11 +367,7 @@ def find_user_or_chat(peer, users, chats):
def get_appropriated_part_size(file_size): def get_appropriated_part_size(file_size):
"""Gets the appropriated part size when uploading or downloading files, """Gets the appropriated part size when uploading or downloading files,
given an initial file size""" given an initial file size"""
if file_size <= 1048576: # 1MB if file_size <= 104857600: # 100MB
return 32
if file_size <= 10485760: # 10MB
return 64
if file_size <= 393216000: # 375MB
return 128 return 128
if file_size <= 786432000: # 750MB if file_size <= 786432000: # 750MB
return 256 return 256

View File

@ -143,7 +143,7 @@ class TLGenerator:
builder.writeln( builder.writeln(
'from {}.utils import get_input_peer, ' 'from {}.utils import get_input_peer, '
'get_input_channel, get_input_user, ' 'get_input_channel, get_input_user, '
'get_input_media'.format('.' * depth) 'get_input_media, get_input_photo'.format('.' * depth)
) )
# Import 'os' for those needing access to 'os.urandom()' # Import 'os' for those needing access to 'os.urandom()'
@ -335,31 +335,27 @@ class TLGenerator:
builder.writeln('))') builder.writeln('))')
builder.end_block() builder.end_block()
# Write the empty() function, which returns an "empty" # Write the static from_reader(reader) function
# instance, in which all attributes are set to None
builder.writeln('@staticmethod') builder.writeln('@staticmethod')
builder.writeln('def empty():') builder.writeln('def from_reader(reader):')
for arg in tlobject.args:
TLGenerator.write_read_code(
builder, arg, tlobject.args, name='_' + arg.name
)
builder.writeln('return {}({})'.format( builder.writeln('return {}({})'.format(
tlobject.class_name(), ', '.join('None' for _ in range(len(args))) tlobject.class_name(), ', '.join(
'{0}=_{0}'.format(a.name) for a in tlobject.sorted_args()
if not a.flag_indicator and not a.generic_definition
)
)) ))
builder.end_block() builder.end_block()
# Write the on_response(self, reader) function # Only requests can have a different response that's not their
builder.writeln('def on_response(self, reader):') # serialized body, that is, we'll be setting their .result.
# Do not read constructor's ID, since
# that's already been read somewhere else
if tlobject.is_function: if tlobject.is_function:
builder.writeln('def on_response(self, reader):')
TLGenerator.write_request_result_code(builder, tlobject) TLGenerator.write_request_result_code(builder, tlobject)
else:
if tlobject.args:
for arg in tlobject.args:
TLGenerator.write_onresponse_code(
builder, arg, tlobject.args
)
else:
# If there were no arguments, we still need an
# on_response method, and hence "pass" if empty
builder.writeln('pass')
builder.end_block() builder.end_block()
# Write the __str__(self) and stringify(self) functions # Write the __str__(self) and stringify(self) functions
@ -406,6 +402,8 @@ class TLGenerator:
TLGenerator.write_get_input(builder, arg, 'get_input_user') TLGenerator.write_get_input(builder, arg, 'get_input_user')
elif arg.type == 'InputMedia' and tlobject.is_function: elif arg.type == 'InputMedia' and tlobject.is_function:
TLGenerator.write_get_input(builder, arg, 'get_input_media') TLGenerator.write_get_input(builder, arg, 'get_input_media')
elif arg.type == 'InputPhoto' and tlobject.is_function:
TLGenerator.write_get_input(builder, arg, 'get_input_photo')
else: else:
builder.writeln('self.{0} = {0}'.format(arg.name)) builder.writeln('self.{0} = {0}'.format(arg.name))
@ -549,9 +547,10 @@ class TLGenerator:
return True # Something was written return True # Something was written
@staticmethod @staticmethod
def write_onresponse_code(builder, arg, args, name=None): def write_read_code(builder, arg, args, name):
""" """
Writes the receive code for the given argument Writes the read code for the given argument, setting the
arg.name variable to its read value.
:param builder: The source code builder :param builder: The source code builder
:param arg: The argument to write :param arg: The argument to write
@ -565,12 +564,17 @@ class TLGenerator:
if arg.generic_definition: if arg.generic_definition:
return # Do nothing, this only specifies a later type return # Do nothing, this only specifies a later type
if name is None:
name = 'self.{}'.format(arg.name)
# The argument may be a flag, only write that flag was given! # The argument may be a flag, only write that flag was given!
was_flag = False was_flag = False
if arg.is_flag: if arg.is_flag:
# Treat 'true' flags as a special case, since they're true if
# they're set, and nothing else needs to actually be read.
if 'true' == arg.type:
builder.writeln(
'{} = bool(flags & {})'.format(name, 1 << arg.flag_index)
)
return
was_flag = True was_flag = True
builder.writeln('if flags & {}:'.format( builder.writeln('if flags & {}:'.format(
1 << arg.flag_index 1 << arg.flag_index
@ -585,11 +589,10 @@ class TLGenerator:
builder.writeln("reader.read_int()") builder.writeln("reader.read_int()")
builder.writeln('{} = []'.format(name)) builder.writeln('{} = []'.format(name))
builder.writeln('_len = reader.read_int()') builder.writeln('for _ in range(reader.read_int()):')
builder.writeln('for _ in range(_len):')
# Temporary disable .is_vector, not to enter this if again # Temporary disable .is_vector, not to enter this if again
arg.is_vector = False arg.is_vector = False
TLGenerator.write_onresponse_code(builder, arg, args, name='_x') TLGenerator.write_read_code(builder, arg, args, name='_x')
builder.writeln('{}.append(_x)'.format(name)) builder.writeln('{}.append(_x)'.format(name))
arg.is_vector = True arg.is_vector = True
@ -642,7 +645,10 @@ class TLGenerator:
builder.end_block() builder.end_block()
if was_flag: if was_flag:
builder.end_block() builder.current_indent -= 1
builder.writeln('else:')
builder.writeln('{} = None'.format(name))
builder.current_indent -= 1
# Restore .is_flag # Restore .is_flag
arg.is_flag = True arg.is_flag = True

View File

@ -107,17 +107,17 @@ class CryptoTests(unittest.TestCase):
@staticmethod @staticmethod
def test_generate_key_data_from_nonce(): def test_generate_key_data_from_nonce():
server_nonce = b'I am the server nonce.' server_nonce = int.from_bytes(b'The 16-bit nonce', byteorder='little')
new_nonce = b'I am a new calculated nonce.' new_nonce = int.from_bytes(b'The new, calculated 32-bit nonce', byteorder='little')
key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce) key, iv = utils.generate_key_data_from_nonce(server_nonce, new_nonce)
expected_key = b'?\xc4\xbd\xdf\rWU\x8a\xf5\x0f+V\xdc\x96up\x1d\xeeG\x00\x81|\x1eg\x8a\x8f{\xf0y\x80\xda\xde' expected_key = b'/\xaa\x7f\xa1\xfcs\xef\xa0\x99zh\x03M\xa4\x8e\xb4\xab\x0eE]b\x95|\xfe\xc0\xf8\x1f\xd4\xa0\xd4\xec\x91'
expected_iv = b'Q\x9dpZ\xb7\xdd\xcb\x82_\xfa\xf4\x90\xecn\x10\x9cD\xd2\x01\x8d\x83\xa0\xa4^\xb8\x91,\x7fI am' expected_iv = b'\xf7\xae\xe3\xc8+=\xc2\xb8\xd1\xe1\x1b\x0e\x10\x07\x9fn\x9e\xdc\x960\x05\xf9\xea\xee\x8b\xa1h The '
assert key == expected_key, 'Key ("{}") does not equal expected ("{}")'.format( assert key == expected_key, 'Key ("{}") does not equal expected ("{}")'.format(
key, expected_key) key, expected_key)
assert iv == expected_iv, 'Key ("{}") does not equal expected ("{}")'.format( assert iv == expected_iv, 'IV ("{}") does not equal expected ("{}")'.format(
key, expected_iv) iv, expected_iv)
@staticmethod @staticmethod
def test_fingerprint_from_key(): def test_fingerprint_from_key():