Handle updates and other refactoring

This commit is contained in:
Andrey Egorov 2017-12-03 02:11:50 +03:00
parent 004c92edbe
commit 984f483b98
3 changed files with 40 additions and 39 deletions

View File

@ -15,8 +15,8 @@ from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
)
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo,
RpcError)
from ..tl.functions.auth import LogOutRequest
logging.getLogger(__name__).addHandler(logging.NullHandler())
@ -32,14 +32,17 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed.
"""
def __init__(self, session, connection, loop=None):
def __init__(self, session, connection, updates_handler, loop=None):
"""Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'.
"""
self.session = session
self.connection = connection
self.updates_handler = updates_handler
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__)
self._read_lock = asyncio.Lock(loop=self._loop)
self._write_lock = asyncio.Lock(loop=self._loop)
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
@ -56,10 +59,6 @@ class MtProtoSender:
self.connection.close()
self._clear_all_pending()
def clone(self):
"""Creates a copy of this MtProtoSender as a new connection"""
return MtProtoSender(self.session, self.connection.clone(), self._loop)
# region Send and receive
async def send(self, *requests):
@ -93,7 +92,7 @@ class MtProtoSender:
"""Sends a message acknowledge for the given msg_id"""
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
async def receive(self, update_state):
async def receive(self):
"""Receives a single message from the connected endpoint.
This method returns nothing, and will only affect other parts
@ -103,6 +102,7 @@ class MtProtoSender:
Any unhandled object (likely updates) will be passed to
update_state.process(TLObject).
"""
await self._read_lock.acquire()
try:
body = await self.connection.recv()
except (BufferError, InvalidChecksumError):
@ -115,10 +115,12 @@ class MtProtoSender:
# and just re-invoke them to avoid problems
self._clear_all_pending()
return
finally:
self._read_lock.release()
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._process_msg(remote_msg_id, remote_seq, reader)
await self._send_acknowledge(remote_msg_id)
# endregion
@ -129,7 +131,7 @@ class MtProtoSender:
"""Sends the given Message(TLObject) encrypted through the network"""
plain_text = \
struct.pack('<QQ', self.session.salt, self.session.id) \
struct.pack('<qQ', self.session.salt, self.session.id) \
+ bytes(message)
msg_key = utils.calc_msg_key(plain_text)
@ -138,7 +140,11 @@ class MtProtoSender:
cipher_text = AES.encrypt_ige(plain_text, key, iv)
result = key_id + msg_key + cipher_text
await self.connection.send(result)
await self._write_lock.acquire()
try:
await self.connection.send(result)
finally:
self._write_lock.release()
def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes"""
@ -171,7 +177,7 @@ class MtProtoSender:
return message, remote_msg_id, remote_sequence
async def _process_msg(self, msg_id, sequence, reader, state):
async def _process_msg(self, msg_id, sequence, reader):
"""Processes and handles a Telegram message.
Returns True if the message was handled correctly and doesn't
@ -191,10 +197,10 @@ class MtProtoSender:
return await self._handle_pong(msg_id, sequence, reader)
if code == MessageContainer.CONSTRUCTOR_ID:
return await self._handle_container(msg_id, sequence, reader, state)
return await self._handle_container(msg_id, sequence, reader)
if code == GzipPacked.CONSTRUCTOR_ID:
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
return await self._handle_gzip_packed(msg_id, sequence, reader)
if code == BadServerSalt.CONSTRUCTOR_ID:
return await self._handle_bad_server_salt(msg_id, sequence, reader)
@ -208,9 +214,6 @@ class MtProtoSender:
if code == MsgNewDetailedInfo.CONSTRUCTOR_ID:
return await self._handle_msg_new_detailed_info(msg_id, sequence, reader)
if code == NewSessionCreated.CONSTRUCTOR_ID:
return await self._handle_new_session_created(msg_id, sequence, reader)
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object()
assert isinstance(ack, MsgsAck)
@ -229,10 +232,7 @@ class MtProtoSender:
# If the code is not parsed manually then it should be a TLObject.
if code in tlobjects:
result = reader.tgread_object()
self.session.process_entities(result)
if state:
state.process(result)
self.updates_handler(result)
return True
self._logger.debug(
@ -307,7 +307,7 @@ class MtProtoSender:
return True
async def _handle_container(self, msg_id, sequence, reader, state):
async def _handle_container(self, msg_id, sequence, reader):
self._logger.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
begin_position = reader.tell_position()
@ -315,7 +315,7 @@ class MtProtoSender:
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not await self._process_msg(inner_msg_id, sequence, reader, state):
if not await self._process_msg(inner_msg_id, sequence, reader):
reader.set_position(begin_position + inner_len)
except:
# If any error is raised, something went wrong; skip the packet
@ -330,9 +330,7 @@ class MtProtoSender:
assert isinstance(bad_salt, BadServerSalt)
# Our salt is unsigned, but the objects work with signed salts
self.session.salt = struct.unpack(
'<Q', struct.pack('<q', bad_salt.new_server_salt)
)[0]
self.session.salt = bad_salt.new_server_salt
self.session.save()
# "the bad_server_salt response is received with the
@ -387,12 +385,6 @@ class MtProtoSender:
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_new_session_created(self, msg_id, sequence, reader):
new_session = reader.tgread_object()
assert isinstance(new_session, NewSessionCreated)
# TODO https://goo.gl/LMyN7A
return True
async def _handle_rpc_result(self, msg_id, sequence, reader):
self._logger.debug('Handling RPC result')
reader.read_int(signed=False) # code
@ -401,7 +393,7 @@ class MtProtoSender:
request = self._pop_request(request_id)
if inner_code == 0x2144ca19: # RPC Error
if inner_code == RpcError.CONSTRUCTOR_ID: # RPC Error
if self.session.report_errors and request:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
@ -422,7 +414,7 @@ class MtProtoSender:
elif request:
self._logger.debug('Reading request response')
if inner_code == 0x3072cfa1: # GZip packed
if inner_code == GzipPacked.CONSTRUCTOR_ID: # GZip packed
unpacked_data = gzip.decompress(reader.tgread_bytes())
with BinaryReader(unpacked_data) as compressed_reader:
request.on_response(compressed_reader)
@ -439,9 +431,9 @@ class MtProtoSender:
self._logger.debug('Lost request will be skipped.')
return False
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
async def _handle_gzip_packed(self, msg_id, sequence, reader):
self._logger.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
return await self._process_msg(msg_id, sequence, compressed_reader, state)
return await self._process_msg(msg_id, sequence, compressed_reader)
# endregion

View File

@ -99,6 +99,7 @@ class TelegramBareClient:
self._sender = MtProtoSender(
self.session,
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._updates_handler,
self._loop
)
@ -426,7 +427,7 @@ class TelegramBareClient:
)
else:
while not all(x.confirm_received.is_set() for x in requests):
await self._sender.receive(update_state=self.updates)
await self._sender.receive()
except BrokenAuthKeyError:
self._logger.error('Broken auth key, a new one will be generated')
@ -705,6 +706,10 @@ class TelegramBareClient:
def list_update_handlers(self):
return self.updates.handlers[:]
def _updates_handler(self, updates):
self.session.process_entities(updates)
self.updates.process(updates)
# endregion
# Constant read
@ -732,7 +737,7 @@ class TelegramBareClient:
# Retry forever, this is instant messaging
await asyncio.sleep(0.1, loop=self._loop)
await self._sender.receive(update_state=self.updates)
await self._sender.receive()
except TimeoutError:
# No problem.
pass

View File

@ -37,6 +37,7 @@ class Session:
self.report_errors = session.report_errors
self.save_entities = session.save_entities
self.flood_sleep_threshold = session.flood_sleep_threshold
self.user_id = session.user_id
else: # str / None
self.session_user_id = session_user_id
@ -51,6 +52,7 @@ class Session:
self.report_errors = True
self.save_entities = True
self.flood_sleep_threshold = 60
self.user_id = 0
self.id = helpers.generate_random_long(signed=False)
self._sequence = 0
@ -78,7 +80,8 @@ class Session:
'server_address': self.server_address,
'auth_key_data':
b64encode(self.auth_key.key).decode('ascii')
if self.auth_key else None
if self.auth_key else None,
'user_id': self.user_id
}
if self.save_entities:
out_dict['entities'] = self.entities.get_input_list()
@ -122,6 +125,7 @@ class Session:
result.layer = data.get('layer', result.layer)
result.server_address = \
data.get('server_address', result.server_address)
result.user_id = data.get('user_id', result.user_id)
# FIXME We need to import the AuthKey here or otherwise
# we get cyclic dependencies.