Change the way in which updates are read and handled

This commit is contained in:
Lonami Exo 2017-05-29 20:41:03 +02:00
parent ae1dbc63da
commit ebe4232b32
2 changed files with 68 additions and 35 deletions

View File

@ -138,23 +138,41 @@ class MtProtoSender:
self._logger.debug('send() released the lock') self._logger.debug('send() released the lock')
def receive(self, request, timeout=timedelta(seconds=5)): def receive(self, request=None, timeout=timedelta(seconds=5), updates=None):
"""Receives the specified MTProtoRequest ("fills in it" """Receives the specified MTProtoRequest ("fills in it"
the received data). This also restores the updates thread. the received data). This also restores the updates thread.
An optional timeout can be specified to cancel the operation An optional timeout can be specified to cancel the operation
if no data has been read after its time delta""" if no data has been read after its time delta.
If 'request' is None, a single item will be read into
the 'updates' list (which cannot be None).
If 'request' is not None, any update received before
reading the request's result will be put there unless
it's None, in which case updates will be ignored.
"""
if request is None and updates is None:
raise ValueError('Both the "request" and "updates"'
'parameters cannot be None at the same time.')
with self._lock: with self._lock:
self._logger.debug('receive() acquired the lock') self._logger.debug('receive() acquired the lock')
# Don't stop trying to receive until we get the request we wanted # Don't stop trying to receive until we get the request we wanted
while not request.confirm_received: # or, if there is no request, until we read an update
while True:
self._logger.info('Trying to .receive() the request result...') self._logger.info('Trying to .receive() the request result...')
seq, body = self._transport.receive(timeout) seq, body = self._transport.receive(timeout)
message, remote_msg_id, remote_sequence = self._decode_msg(body) message, remote_msg_id, remote_sequence = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
self._process_msg(remote_msg_id, remote_sequence, reader, self._process_msg(remote_msg_id, remote_sequence, reader,
request) request, updates)
if request is None:
if updates:
break # No request but one update read, exit
elif request.confirm_received:
break # Request, and result read, exit
self._logger.info('Request result received') self._logger.info('Request result received')
@ -162,6 +180,12 @@ class MtProtoSender:
self._waiting_receive.clear() self._waiting_receive.clear()
self._logger.debug('receive() released the lock') self._logger.debug('receive() released the lock')
def receive_update(self, timeout=timedelta(seconds=5)):
"""Receives an update object and returns its result"""
updates = []
self.receive(timeout=timeout, updates=updates)
return updates[0]
# endregion # endregion
# region Low level processing # region Low level processing
@ -221,7 +245,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence return message, remote_msg_id, remote_sequence
def _process_msg(self, msg_id, sequence, reader, request=None): def _process_msg(self, msg_id, sequence, reader, request, updates):
"""Processes and handles a Telegram message""" """Processes and handles a Telegram message"""
# TODO Check salt, session_id and sequence_number # TODO Check salt, session_id and sequence_number
@ -232,19 +256,28 @@ class MtProtoSender:
# The following codes are "parsed manually" # The following codes are "parsed manually"
if code == 0xf35c6d01: # rpc_result, (response of an RPC call, i.e., we sent a request) if code == 0xf35c6d01: # rpc_result, (response of an RPC call, i.e., we sent a request)
return self._handle_rpc_result(msg_id, sequence, reader, request) return self._handle_rpc_result(
msg_id, sequence, reader, request)
if code == 0x347773c5: # pong if code == 0x347773c5: # pong
return self._handle_pong(msg_id, sequence, reader, request) return self._handle_pong(
msg_id, sequence, reader, request)
if code == 0x73f1f8dc: # msg_container if code == 0x73f1f8dc: # msg_container
return self._handle_container(msg_id, sequence, reader, request) return self._handle_container(
msg_id, sequence, reader, request, updates)
if code == 0x3072cfa1: # gzip_packed if code == 0x3072cfa1: # gzip_packed
return self._handle_gzip_packed(msg_id, sequence, reader, request) return self._handle_gzip_packed(
msg_id, sequence, reader, request, updates)
if code == 0xedab447b: # bad_server_salt if code == 0xedab447b: # bad_server_salt
return self._handle_bad_server_salt(msg_id, sequence, reader, return self._handle_bad_server_salt(
request) msg_id, sequence, reader, request)
if code == 0xa7eff811: # bad_msg_notification if code == 0xa7eff811: # bad_msg_notification
return self._handle_bad_msg_notification(msg_id, sequence, reader) return self._handle_bad_msg_notification(
msg_id, sequence, reader)
# msgs_ack, it may handle the request we wanted # msgs_ack, it may handle the request we wanted
if code == 0x62d6b459: if code == 0x62d6b459:
@ -262,7 +295,14 @@ class MtProtoSender:
# In this case, we will simply treat the incoming TLObject as an Update, # In this case, we will simply treat the incoming TLObject as an Update,
# if we can first find a matching TLObject # if we can first find a matching TLObject
if code in tlobjects.keys(): if code in tlobjects.keys():
return self._handle_update(msg_id, sequence, reader) result = reader.tgread_object()
if updates is None:
self._logger.debug('Ignored update for %s', repr(result))
else:
self._logger.debug('Read update for %s', repr(result))
updates.append(result)
return False
print('Unknown message: {}'.format(hex(code))) print('Unknown message: {}'.format(hex(code)))
return False return False
@ -271,14 +311,6 @@ class MtProtoSender:
# region Message handling # region Message handling
def _handle_update(self, msg_id, sequence, reader):
tlobject = reader.tgread_object()
self._logger.debug('Handling update for object %s', repr(tlobject))
for handler in self._on_update_handlers:
handler(tlobject)
return False
def _handle_pong(self, msg_id, sequence, reader, request): def _handle_pong(self, msg_id, sequence, reader, request):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
@ -290,7 +322,7 @@ class MtProtoSender:
return False return False
def _handle_container(self, msg_id, sequence, reader, request): def _handle_container(self, msg_id, sequence, reader, request, updates):
self._logger.debug('Handling container') self._logger.debug('Handling container')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
size = reader.read_int() size = reader.read_int()
@ -300,9 +332,10 @@ class MtProtoSender:
inner_length = reader.read_int() inner_length = reader.read_int()
begin_position = reader.tell_position() begin_position = reader.tell_position()
# note: this code is IMPORTANT for skipping RPC results of lost # Note that this code is IMPORTANT for skipping RPC results of
# requests (for example, ones from the previous connection session) # lost requests (i.e., ones from the previous connection session)
if not self._process_msg(inner_msg_id, sequence, reader, request): if not self._process_msg(
inner_msg_id, sequence, reader, request, updates):
reader.set_position(begin_position + inner_length) reader.set_position(begin_position + inner_length)
return False return False
@ -393,7 +426,7 @@ class MtProtoSender:
# session, it will be skipped by the handle_container() # session, it will be skipped by the handle_container()
self._logger.warning('RPC result found for unknown request (maybe from previous connection session)') self._logger.warning('RPC result found for unknown request (maybe from previous connection session)')
def _handle_gzip_packed(self, msg_id, sequence, reader, request): def _handle_gzip_packed(self, msg_id, sequence, reader, request, updates):
self._logger.debug('Handling gzip packed data') self._logger.debug('Handling gzip packed data')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code
packed_data = reader.tgread_bytes() packed_data = reader.tgread_bytes()
@ -401,7 +434,7 @@ class MtProtoSender:
with BinaryReader(unpacked_data) as compressed_reader: with BinaryReader(unpacked_data) as compressed_reader:
return self._process_msg(msg_id, sequence, compressed_reader, return self._process_msg(msg_id, sequence, compressed_reader,
request) request, updates)
# endregion # endregion
@ -457,14 +490,10 @@ class MtProtoSender:
self._updates_thread_receiving.set() self._updates_thread_receiving.set()
self._logger.debug('Trying to receive updates from the updates thread') self._logger.debug('Trying to receive updates from the updates thread')
seq, body = self._transport.receive(timeout) result = self.receive_update(timeout=timeout)
message, remote_msg_id, remote_sequence = self._decode_msg(
body)
self._logger.info('Received update from the updates thread') self._logger.info('Received update from the updates thread')
with BinaryReader(message) as reader: for handler in self._on_update_handlers:
self._process_msg(remote_msg_id, remote_sequence, handler(result)
reader)
except TimeoutError: except TimeoutError:
self._logger.debug('Receiving updates timed out') self._logger.debug('Receiving updates timed out')

View File

@ -169,8 +169,12 @@ class TelegramClient:
raise ValueError('You must be connected to invoke requests!') raise ValueError('You must be connected to invoke requests!')
try: try:
updates = []
self.sender.send(request) self.sender.send(request)
self.sender.receive(request, timeout) self.sender.receive(request, timeout, updates=updates)
for update in updates:
for handler in self.sender._on_update_handlers:
handler(update)
return request.result return request.result