Fixed bugs when handling updates

GZipPacked updates are now handled correctly.
Also, fixed another bug which did not allow
resending a request when BadServerSalt occured.
This commit is contained in:
Lonami 2016-09-10 18:05:20 +02:00
parent 0068c0fd8b
commit 5b4be5b85e
3 changed files with 100 additions and 80 deletions

View File

@ -42,7 +42,13 @@ if __name__ == '__main__':
print('{}. {}'.format(i, display)) print('{}. {}'.format(i, display))
# Let the user decide who they want to talk to # Let the user decide who they want to talk to
i = None
while i is None:
try:
i = int(input('Who do you want to send messages to (0 to exit)?: ')) - 1 i = int(input('Who do you want to send messages to (0 to exit)?: ')) - 1
except ValueError:
pass
if i == -1: if i == -1:
break break

View File

@ -14,7 +14,7 @@ from tl.all_tlobjects import tlobjects
class MtProtoSender: class MtProtoSender:
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)""" """MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
def __init__(self, transport, session, check_updates=True): def __init__(self, transport, session):
self.transport = transport self.transport = transport
self.session = session self.session = session
@ -24,17 +24,13 @@ class MtProtoSender:
# Store a Lock instance to make this class safely multi-threaded # Store a Lock instance to make this class safely multi-threaded
self.lock = Lock() self.lock = Lock()
if check_updates:
self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread') self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread')
self.updates_thread_running = True self.updates_thread_running = False
self.updates_thread_receiving = False self.updates_thread_receiving = False
self.updates_thread.start()
def disconnect(self): def disconnect(self):
"""Disconnects and **stops all the running threads** if any""" """Disconnects and **stops all the running threads** if any"""
self.updates_thread_running = False self.set_listen_for_updates(enabled=False)
self.transport.cancel_receive()
self.transport.close() self.transport.close()
def add_update_handler(self, handler): def add_update_handler(self, handler):
@ -54,7 +50,7 @@ class MtProtoSender:
# region Send and receive # region Send and receive
def send(self, request): def send(self, request, resend=False):
"""Sends the specified MTProtoRequest, previously sending any message """Sends the specified MTProtoRequest, previously sending any message
which needed confirmation. This also pauses the updates thread""" which needed confirmation. This also pauses the updates thread"""
@ -64,7 +60,8 @@ class MtProtoSender:
if self.updates_thread_receiving: if self.updates_thread_receiving:
self.transport.cancel_receive() self.transport.cancel_receive()
# Now only us can be using this method # Now only us can be using this method if we're not resending
if not resend:
self.lock.acquire() self.lock.acquire()
# If any message needs confirmation send an AckRequest first # If any message needs confirmation send an AckRequest first
@ -161,9 +158,8 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence return message, remote_msg_id, remote_sequence
def process_msg(self, msg_id, sequence, reader, request, only_updates=False): def process_msg(self, msg_id, sequence, reader, request=None):
"""Processes and handles a Telegram message. Optionally, this """Processes and handles a Telegram message"""
function will only check for updates (hence the request can be None)"""
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
self.need_confirmation.append(msg_id) self.need_confirmation.append(msg_id)
@ -171,10 +167,10 @@ class MtProtoSender:
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
reader.seek(-4) reader.seek(-4)
# The following codes are "parsed manually" (and do not refer to an update) # The following codes are "parsed manually"
if not only_updates: if code == 0xf35c6d01: # rpc_result, (response of an RPC call, i.e., we sent a request)
if code == 0xf35c6d01: # rpc_result
return self.handle_rpc_result(msg_id, sequence, reader, request) return self.handle_rpc_result(msg_id, sequence, reader, request)
if code == 0x73f1f8dc: # msg_container if code == 0x73f1f8dc: # msg_container
return self.handle_container(msg_id, sequence, reader, request) return self.handle_container(msg_id, sequence, reader, request)
if code == 0x3072cfa1: # gzip_packed if code == 0x3072cfa1: # gzip_packed
@ -218,7 +214,7 @@ class MtProtoSender:
return False return False
def handle_bad_server_salt(self, msg_id, sequence, reader, mtproto_request): def handle_bad_server_salt(self, msg_id, sequence, reader, request):
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
bad_msg_id = reader.read_long(signed=False) bad_msg_id = reader.read_long(signed=False)
bad_msg_seq_no = reader.read_int() bad_msg_seq_no = reader.read_int()
@ -227,8 +223,11 @@ class MtProtoSender:
self.session.salt = new_salt self.session.salt = new_salt
if request is None:
raise ValueError('Tried to handle a bad server salt with no request specified')
# Resend # Resend
self.send(mtproto_request) self.send(request, resend=True)
return True return True
@ -241,14 +240,13 @@ class MtProtoSender:
raise BadMessageError(error_code) raise BadMessageError(error_code)
def handle_rpc_result(self, msg_id, sequence, reader, request): def handle_rpc_result(self, msg_id, sequence, reader, request):
if not request:
raise ValueError('RPC results should only happen after a request was sent')
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
request_id = reader.read_long(signed=False) request_id = reader.read_long(signed=False)
inner_code = reader.read_int(signed=False) inner_code = reader.read_int(signed=False)
if not request:
raise ValueError('Cannot handle RPC results if no request was specified. '
'This should only happen when the updates thread does not work properly.')
if request_id == request.msg_id: if request_id == request.msg_id:
request.confirm_received = True request.confirm_received = True
@ -266,27 +264,37 @@ class MtProtoSender:
else: else:
raise error raise error
else: else:
if inner_code == 0x3072cfa1: # GZip packed if inner_code == 0x3072cfa1: # GZip packed
unpacked_data = gzip.decompress(reader.tgread_bytes()) unpacked_data = gzip.decompress(reader.tgread_bytes())
with BinaryReader(unpacked_data) as compressed_reader: with BinaryReader(unpacked_data) as compressed_reader:
request.on_response(compressed_reader) request.on_response(compressed_reader)
else: else:
reader.seek(-4) reader.seek(-4)
request.on_response(reader) request.on_response(reader)
def handle_gzip_packed(self, msg_id, sequence, reader, mtproto_request): def handle_gzip_packed(self, msg_id, sequence, reader, request):
code = reader.read_int(signed=False) code = reader.read_int(signed=False)
packed_data = reader.tgread_bytes() packed_data = reader.tgread_bytes()
unpacked_data = gzip.decompress(packed_data) unpacked_data = gzip.decompress(packed_data)
with BinaryReader(unpacked_data) as compressed_reader: with BinaryReader(unpacked_data) as compressed_reader:
self.process_msg(msg_id, sequence, compressed_reader, mtproto_request) return self.process_msg(msg_id, sequence, compressed_reader, request)
# endregion # endregion
def set_listen_for_updates(self, enabled):
if enabled:
if not self.updates_thread_running:
self.updates_thread_running = True
self.updates_thread_receiving = False
self.updates_thread.start()
else:
self.updates_thread_running = False
if self.updates_thread_receiving:
self.transport.cancel_receive()
def updates_thread_method(self): def updates_thread_method(self):
"""This method will run until specified and listen for incoming updates""" """This method will run until specified and listen for incoming updates"""
while self.updates_thread_running: while self.updates_thread_running:
@ -297,7 +305,7 @@ class MtProtoSender:
message, remote_msg_id, remote_sequence = self.decode_msg(body) message, remote_msg_id, remote_sequence = self.decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self.process_msg(remote_msg_id, remote_sequence, reader, request=None, only_updates=True) self.process_msg(remote_msg_id, remote_sequence, reader)
except ReadCancelledError: except ReadCancelledError:
pass pass

View File

@ -10,6 +10,7 @@ from network import MtProtoSender, TcpTransport
from parser.markdown_parser import parse_message_entities from parser.markdown_parser import parse_message_entities
# For sending and receiving requests # For sending and receiving requests
from tl import MTProtoRequest
from tl import Session from tl import Session
from tl.types import PeerUser, PeerChat, PeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel, InputPeerEmpty from tl.types import PeerUser, PeerChat, PeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel, InputPeerEmpty
from tl.functions import InvokeWithLayerRequest, InitConnectionRequest from tl.functions import InvokeWithLayerRequest, InitConnectionRequest
@ -62,18 +63,22 @@ class TelegramClient:
# Now it's time to send an InitConnectionRequest # Now it's time to send an InitConnectionRequest
# This must always be invoked with the layer we'll be using # This must always be invoked with the layer we'll be using
request = InvokeWithLayerRequest(layer=self.layer,
query = InitConnectionRequest(api_id=self.api_id, query = InitConnectionRequest(api_id=self.api_id,
device_model=platform.node(), device_model=platform.node(),
system_version=platform.system(), system_version=platform.system(),
app_version='0.2', app_version='0.3',
lang_code='en', lang_code='en',
query=GetConfigRequest())) query=GetConfigRequest())
self.sender.send(request) result = self.invoke(InvokeWithLayerRequest(layer=self.layer, query=query))
self.sender.receive(request)
self.dc_options = request.result.dc_options # Only listen for updates if we're authorized
if self.is_user_authorized():
self.sender.set_listen_for_updates(True)
# We're only interested in the DC options,
# although many other options are available!
self.dc_options = result.dc_options
return True return True
except RPCError as error: except RPCError as error:
print('Could not stabilise initial connection: {}'.format(error)) print('Could not stabilise initial connection: {}'.format(error))
@ -114,11 +119,9 @@ class TelegramClient:
completed = False completed = False
while not completed: while not completed:
try: try:
self.sender.send(request) result = self.invoke(request)
self.sender.receive(request) self.phone_code_hashes[phone_number] = result.phone_code_hash
completed = True completed = True
if request.result:
self.phone_code_hashes[phone_number] = request.result.phone_code_hash
except InvalidDCError as error: except InvalidDCError as error:
self.reconnect_to_dc(error.new_dc) self.reconnect_to_dc(error.new_dc)
@ -137,19 +140,19 @@ class TelegramClient:
self.session.user = request.result.user self.session.user = request.result.user
self.session.save() self.session.save()
# Now that we're authorized, we can listen for incoming updates
self.sender.set_listen_for_updates(True)
return self.session.user return self.session.user
def get_dialogs(self, count=10, offset_date=None, offset_id=0, offset_peer=InputPeerEmpty()): def get_dialogs(self, count=10, offset_date=None, offset_id=0, offset_peer=InputPeerEmpty()):
"""Returns a tuple of lists ([dialogs], [displays], [input_peers]) with 'count' items each""" """Returns a tuple of lists ([dialogs], [displays], [input_peers]) with 'count' items each"""
request = GetDialogsRequest(offset_date=TelegramClient.get_tg_date(offset_date),
r = self.invoke(GetDialogsRequest(offset_date=TelegramClient.get_tg_date(offset_date),
offset_id=offset_id, offset_id=offset_id,
offset_peer=offset_peer, offset_peer=offset_peer,
limit=count) limit=count))
self.sender.send(request)
self.sender.receive(request)
r = request.result
return (r.dialogs, return (r.dialogs,
[self.find_display_name(d.peer, r.users, r.chats) for d in r.dialogs], [self.find_display_name(d.peer, r.users, r.chats) for d in r.dialogs],
[self.find_input_peer(d.peer, r.users, r.chats) for d in r.dialogs]) [self.find_input_peer(d.peer, r.users, r.chats) for d in r.dialogs])
@ -161,14 +164,11 @@ class TelegramClient:
else: else:
msg, entities = message, [] msg, entities = message, []
request = SendMessageRequest(peer=input_peer, self.invoke(SendMessageRequest(peer=input_peer,
message=msg, message=msg,
random_id=utils.generate_random_long(), random_id=utils.generate_random_long(),
entities=entities, entities=entities,
no_webpage=no_web_page) no_webpage=no_web_page))
self.sender.send(request)
self.sender.receive(request)
def get_message_history(self, input_peer, limit=20, def get_message_history(self, input_peer, limit=20,
offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0): offset_date=None, offset_id=0, max_id=0, min_id=0, add_offset=0):
@ -186,18 +186,14 @@ class TelegramClient:
:return: A tuple containing total message count and two more lists ([messages], [senders]). :return: A tuple containing total message count and two more lists ([messages], [senders]).
Note that the sender can be null if it was not found! Note that the sender can be null if it was not found!
""" """
request = GetHistoryRequest(input_peer, result = self.invoke(GetHistoryRequest(input_peer,
limit=limit, limit=limit,
offset_date=self.get_tg_date(offset_date), offset_date=self.get_tg_date(offset_date),
offset_id=offset_id, offset_id=offset_id,
max_id=max_id, max_id=max_id,
min_id=min_id, min_id=min_id,
add_offset=add_offset) add_offset=add_offset))
self.sender.send(request)
self.sender.receive(request)
result = request.result
# The result may be a messages slice (not all messages were retrieved) or # The result may be a messages slice (not all messages were retrieved) or
# simply a messages TLObject. In the later case, no "count" attribute is specified: # simply a messages TLObject. In the later case, no "count" attribute is specified:
# the total messages count is retrieved by counting all the retrieved messages # the total messages count is retrieved by counting all the retrieved messages
@ -210,6 +206,16 @@ class TelegramClient:
for msg in result.messages # ...from all the messages... for msg in result.messages # ...from all the messages...
for usr in result.users]) # ...from all of the available users for usr in result.users]) # ...from all of the available users
def invoke(self, request):
"""Invokes an MTProtoRequest and returns its results"""
if not issubclass(type(request), MTProtoRequest):
raise ValueError('You can only invoke MtProtoRequests')
self.sender.send(request)
self.sender.receive(request)
return request.result
# endregion # endregion
# region Utilities # region Utilities