mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-03-09 21:55:48 +03:00
Handle updates and other refactoring
This commit is contained in:
parent
004c92edbe
commit
984f483b98
|
@ -15,8 +15,8 @@ from ..tl import TLMessage, MessageContainer, GzipPacked
|
||||||
from ..tl.all_tlobjects import tlobjects
|
from ..tl.all_tlobjects import tlobjects
|
||||||
from ..tl.types import (
|
from ..tl.types import (
|
||||||
MsgsAck, Pong, BadServerSalt, BadMsgNotification,
|
MsgsAck, Pong, BadServerSalt, BadMsgNotification,
|
||||||
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
|
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo,
|
||||||
)
|
RpcError)
|
||||||
from ..tl.functions.auth import LogOutRequest
|
from ..tl.functions.auth import LogOutRequest
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
@ -32,14 +32,17 @@ class MtProtoSender:
|
||||||
in parallel, so thread-safety (hence locking) isn't needed.
|
in parallel, so thread-safety (hence locking) isn't needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, session, connection, loop=None):
|
def __init__(self, session, connection, updates_handler, loop=None):
|
||||||
"""Creates a new MtProtoSender configured to send messages through
|
"""Creates a new MtProtoSender configured to send messages through
|
||||||
'connection' and using the parameters from 'session'.
|
'connection' and using the parameters from 'session'.
|
||||||
"""
|
"""
|
||||||
self.session = session
|
self.session = session
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
self.updates_handler = updates_handler
|
||||||
self._loop = loop if loop else asyncio.get_event_loop()
|
self._loop = loop if loop else asyncio.get_event_loop()
|
||||||
self._logger = logging.getLogger(__name__)
|
self._logger = logging.getLogger(__name__)
|
||||||
|
self._read_lock = asyncio.Lock(loop=self._loop)
|
||||||
|
self._write_lock = asyncio.Lock(loop=self._loop)
|
||||||
|
|
||||||
# Requests (as msg_id: Message) sent waiting to be received
|
# Requests (as msg_id: Message) sent waiting to be received
|
||||||
self._pending_receive = {}
|
self._pending_receive = {}
|
||||||
|
@ -56,10 +59,6 @@ class MtProtoSender:
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
self._clear_all_pending()
|
self._clear_all_pending()
|
||||||
|
|
||||||
def clone(self):
|
|
||||||
"""Creates a copy of this MtProtoSender as a new connection"""
|
|
||||||
return MtProtoSender(self.session, self.connection.clone(), self._loop)
|
|
||||||
|
|
||||||
# region Send and receive
|
# region Send and receive
|
||||||
|
|
||||||
async def send(self, *requests):
|
async def send(self, *requests):
|
||||||
|
@ -93,7 +92,7 @@ class MtProtoSender:
|
||||||
"""Sends a message acknowledge for the given msg_id"""
|
"""Sends a message acknowledge for the given msg_id"""
|
||||||
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
|
||||||
|
|
||||||
async def receive(self, update_state):
|
async def receive(self):
|
||||||
"""Receives a single message from the connected endpoint.
|
"""Receives a single message from the connected endpoint.
|
||||||
|
|
||||||
This method returns nothing, and will only affect other parts
|
This method returns nothing, and will only affect other parts
|
||||||
|
@ -103,6 +102,7 @@ class MtProtoSender:
|
||||||
Any unhandled object (likely updates) will be passed to
|
Any unhandled object (likely updates) will be passed to
|
||||||
update_state.process(TLObject).
|
update_state.process(TLObject).
|
||||||
"""
|
"""
|
||||||
|
await self._read_lock.acquire()
|
||||||
try:
|
try:
|
||||||
body = await self.connection.recv()
|
body = await self.connection.recv()
|
||||||
except (BufferError, InvalidChecksumError):
|
except (BufferError, InvalidChecksumError):
|
||||||
|
@ -115,10 +115,12 @@ class MtProtoSender:
|
||||||
# and just re-invoke them to avoid problems
|
# and just re-invoke them to avoid problems
|
||||||
self._clear_all_pending()
|
self._clear_all_pending()
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
self._read_lock.release()
|
||||||
|
|
||||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||||
with BinaryReader(message) as reader:
|
with BinaryReader(message) as reader:
|
||||||
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
await self._process_msg(remote_msg_id, remote_seq, reader)
|
||||||
await self._send_acknowledge(remote_msg_id)
|
await self._send_acknowledge(remote_msg_id)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
@ -129,7 +131,7 @@ class MtProtoSender:
|
||||||
"""Sends the given Message(TLObject) encrypted through the network"""
|
"""Sends the given Message(TLObject) encrypted through the network"""
|
||||||
|
|
||||||
plain_text = \
|
plain_text = \
|
||||||
struct.pack('<QQ', self.session.salt, self.session.id) \
|
struct.pack('<qQ', self.session.salt, self.session.id) \
|
||||||
+ bytes(message)
|
+ bytes(message)
|
||||||
|
|
||||||
msg_key = utils.calc_msg_key(plain_text)
|
msg_key = utils.calc_msg_key(plain_text)
|
||||||
|
@ -138,7 +140,11 @@ class MtProtoSender:
|
||||||
cipher_text = AES.encrypt_ige(plain_text, key, iv)
|
cipher_text = AES.encrypt_ige(plain_text, key, iv)
|
||||||
|
|
||||||
result = key_id + msg_key + cipher_text
|
result = key_id + msg_key + cipher_text
|
||||||
|
await self._write_lock.acquire()
|
||||||
|
try:
|
||||||
await self.connection.send(result)
|
await self.connection.send(result)
|
||||||
|
finally:
|
||||||
|
self._write_lock.release()
|
||||||
|
|
||||||
def _decode_msg(self, body):
|
def _decode_msg(self, body):
|
||||||
"""Decodes an received encrypted message body bytes"""
|
"""Decodes an received encrypted message body bytes"""
|
||||||
|
@ -171,7 +177,7 @@ class MtProtoSender:
|
||||||
|
|
||||||
return message, remote_msg_id, remote_sequence
|
return message, remote_msg_id, remote_sequence
|
||||||
|
|
||||||
async def _process_msg(self, msg_id, sequence, reader, state):
|
async def _process_msg(self, msg_id, sequence, reader):
|
||||||
"""Processes and handles a Telegram message.
|
"""Processes and handles a Telegram message.
|
||||||
|
|
||||||
Returns True if the message was handled correctly and doesn't
|
Returns True if the message was handled correctly and doesn't
|
||||||
|
@ -191,10 +197,10 @@ class MtProtoSender:
|
||||||
return await self._handle_pong(msg_id, sequence, reader)
|
return await self._handle_pong(msg_id, sequence, reader)
|
||||||
|
|
||||||
if code == MessageContainer.CONSTRUCTOR_ID:
|
if code == MessageContainer.CONSTRUCTOR_ID:
|
||||||
return await self._handle_container(msg_id, sequence, reader, state)
|
return await self._handle_container(msg_id, sequence, reader)
|
||||||
|
|
||||||
if code == GzipPacked.CONSTRUCTOR_ID:
|
if code == GzipPacked.CONSTRUCTOR_ID:
|
||||||
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
|
return await self._handle_gzip_packed(msg_id, sequence, reader)
|
||||||
|
|
||||||
if code == BadServerSalt.CONSTRUCTOR_ID:
|
if code == BadServerSalt.CONSTRUCTOR_ID:
|
||||||
return await self._handle_bad_server_salt(msg_id, sequence, reader)
|
return await self._handle_bad_server_salt(msg_id, sequence, reader)
|
||||||
|
@ -208,9 +214,6 @@ class MtProtoSender:
|
||||||
if code == MsgNewDetailedInfo.CONSTRUCTOR_ID:
|
if code == MsgNewDetailedInfo.CONSTRUCTOR_ID:
|
||||||
return await self._handle_msg_new_detailed_info(msg_id, sequence, reader)
|
return await self._handle_msg_new_detailed_info(msg_id, sequence, reader)
|
||||||
|
|
||||||
if code == NewSessionCreated.CONSTRUCTOR_ID:
|
|
||||||
return await self._handle_new_session_created(msg_id, sequence, reader)
|
|
||||||
|
|
||||||
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
|
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
|
||||||
ack = reader.tgread_object()
|
ack = reader.tgread_object()
|
||||||
assert isinstance(ack, MsgsAck)
|
assert isinstance(ack, MsgsAck)
|
||||||
|
@ -229,10 +232,7 @@ class MtProtoSender:
|
||||||
# If the code is not parsed manually then it should be a TLObject.
|
# If the code is not parsed manually then it should be a TLObject.
|
||||||
if code in tlobjects:
|
if code in tlobjects:
|
||||||
result = reader.tgread_object()
|
result = reader.tgread_object()
|
||||||
self.session.process_entities(result)
|
self.updates_handler(result)
|
||||||
if state:
|
|
||||||
state.process(result)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
|
@ -307,7 +307,7 @@ class MtProtoSender:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _handle_container(self, msg_id, sequence, reader, state):
|
async def _handle_container(self, msg_id, sequence, reader):
|
||||||
self._logger.debug('Handling container')
|
self._logger.debug('Handling container')
|
||||||
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
|
||||||
begin_position = reader.tell_position()
|
begin_position = reader.tell_position()
|
||||||
|
@ -315,7 +315,7 @@ class MtProtoSender:
|
||||||
# Note that this code is IMPORTANT for skipping RPC results of
|
# Note that this code is IMPORTANT for skipping RPC results of
|
||||||
# lost requests (i.e., ones from the previous connection session)
|
# lost requests (i.e., ones from the previous connection session)
|
||||||
try:
|
try:
|
||||||
if not await self._process_msg(inner_msg_id, sequence, reader, state):
|
if not await self._process_msg(inner_msg_id, sequence, reader):
|
||||||
reader.set_position(begin_position + inner_len)
|
reader.set_position(begin_position + inner_len)
|
||||||
except:
|
except:
|
||||||
# If any error is raised, something went wrong; skip the packet
|
# If any error is raised, something went wrong; skip the packet
|
||||||
|
@ -330,9 +330,7 @@ class MtProtoSender:
|
||||||
assert isinstance(bad_salt, BadServerSalt)
|
assert isinstance(bad_salt, BadServerSalt)
|
||||||
|
|
||||||
# Our salt is unsigned, but the objects work with signed salts
|
# Our salt is unsigned, but the objects work with signed salts
|
||||||
self.session.salt = struct.unpack(
|
self.session.salt = bad_salt.new_server_salt
|
||||||
'<Q', struct.pack('<q', bad_salt.new_server_salt)
|
|
||||||
)[0]
|
|
||||||
self.session.save()
|
self.session.save()
|
||||||
|
|
||||||
# "the bad_server_salt response is received with the
|
# "the bad_server_salt response is received with the
|
||||||
|
@ -387,12 +385,6 @@ class MtProtoSender:
|
||||||
await self._send_acknowledge(msg_new.answer_msg_id)
|
await self._send_acknowledge(msg_new.answer_msg_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _handle_new_session_created(self, msg_id, sequence, reader):
|
|
||||||
new_session = reader.tgread_object()
|
|
||||||
assert isinstance(new_session, NewSessionCreated)
|
|
||||||
# TODO https://goo.gl/LMyN7A
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _handle_rpc_result(self, msg_id, sequence, reader):
|
async def _handle_rpc_result(self, msg_id, sequence, reader):
|
||||||
self._logger.debug('Handling RPC result')
|
self._logger.debug('Handling RPC result')
|
||||||
reader.read_int(signed=False) # code
|
reader.read_int(signed=False) # code
|
||||||
|
@ -401,7 +393,7 @@ class MtProtoSender:
|
||||||
|
|
||||||
request = self._pop_request(request_id)
|
request = self._pop_request(request_id)
|
||||||
|
|
||||||
if inner_code == 0x2144ca19: # RPC Error
|
if inner_code == RpcError.CONSTRUCTOR_ID: # RPC Error
|
||||||
if self.session.report_errors and request:
|
if self.session.report_errors and request:
|
||||||
error = rpc_message_to_error(
|
error = rpc_message_to_error(
|
||||||
reader.read_int(), reader.tgread_string(),
|
reader.read_int(), reader.tgread_string(),
|
||||||
|
@ -422,7 +414,7 @@ class MtProtoSender:
|
||||||
|
|
||||||
elif request:
|
elif request:
|
||||||
self._logger.debug('Reading request response')
|
self._logger.debug('Reading request response')
|
||||||
if inner_code == 0x3072cfa1: # GZip packed
|
if inner_code == GzipPacked.CONSTRUCTOR_ID: # GZip packed
|
||||||
unpacked_data = gzip.decompress(reader.tgread_bytes())
|
unpacked_data = gzip.decompress(reader.tgread_bytes())
|
||||||
with BinaryReader(unpacked_data) as compressed_reader:
|
with BinaryReader(unpacked_data) as compressed_reader:
|
||||||
request.on_response(compressed_reader)
|
request.on_response(compressed_reader)
|
||||||
|
@ -439,9 +431,9 @@ class MtProtoSender:
|
||||||
self._logger.debug('Lost request will be skipped.')
|
self._logger.debug('Lost request will be skipped.')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
async def _handle_gzip_packed(self, msg_id, sequence, reader):
|
||||||
self._logger.debug('Handling gzip packed data')
|
self._logger.debug('Handling gzip packed data')
|
||||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||||
return await self._process_msg(msg_id, sequence, compressed_reader, state)
|
return await self._process_msg(msg_id, sequence, compressed_reader)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
|
@ -99,6 +99,7 @@ class TelegramBareClient:
|
||||||
self._sender = MtProtoSender(
|
self._sender = MtProtoSender(
|
||||||
self.session,
|
self.session,
|
||||||
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
|
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
|
||||||
|
self._updates_handler,
|
||||||
self._loop
|
self._loop
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -426,7 +427,7 @@ class TelegramBareClient:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
while not all(x.confirm_received.is_set() for x in requests):
|
while not all(x.confirm_received.is_set() for x in requests):
|
||||||
await self._sender.receive(update_state=self.updates)
|
await self._sender.receive()
|
||||||
|
|
||||||
except BrokenAuthKeyError:
|
except BrokenAuthKeyError:
|
||||||
self._logger.error('Broken auth key, a new one will be generated')
|
self._logger.error('Broken auth key, a new one will be generated')
|
||||||
|
@ -705,6 +706,10 @@ class TelegramBareClient:
|
||||||
def list_update_handlers(self):
|
def list_update_handlers(self):
|
||||||
return self.updates.handlers[:]
|
return self.updates.handlers[:]
|
||||||
|
|
||||||
|
def _updates_handler(self, updates):
|
||||||
|
self.session.process_entities(updates)
|
||||||
|
self.updates.process(updates)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# Constant read
|
# Constant read
|
||||||
|
@ -732,7 +737,7 @@ class TelegramBareClient:
|
||||||
# Retry forever, this is instant messaging
|
# Retry forever, this is instant messaging
|
||||||
await asyncio.sleep(0.1, loop=self._loop)
|
await asyncio.sleep(0.1, loop=self._loop)
|
||||||
|
|
||||||
await self._sender.receive(update_state=self.updates)
|
await self._sender.receive()
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
# No problem.
|
# No problem.
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -37,6 +37,7 @@ class Session:
|
||||||
self.report_errors = session.report_errors
|
self.report_errors = session.report_errors
|
||||||
self.save_entities = session.save_entities
|
self.save_entities = session.save_entities
|
||||||
self.flood_sleep_threshold = session.flood_sleep_threshold
|
self.flood_sleep_threshold = session.flood_sleep_threshold
|
||||||
|
self.user_id = session.user_id
|
||||||
|
|
||||||
else: # str / None
|
else: # str / None
|
||||||
self.session_user_id = session_user_id
|
self.session_user_id = session_user_id
|
||||||
|
@ -51,6 +52,7 @@ class Session:
|
||||||
self.report_errors = True
|
self.report_errors = True
|
||||||
self.save_entities = True
|
self.save_entities = True
|
||||||
self.flood_sleep_threshold = 60
|
self.flood_sleep_threshold = 60
|
||||||
|
self.user_id = 0
|
||||||
|
|
||||||
self.id = helpers.generate_random_long(signed=False)
|
self.id = helpers.generate_random_long(signed=False)
|
||||||
self._sequence = 0
|
self._sequence = 0
|
||||||
|
@ -78,7 +80,8 @@ class Session:
|
||||||
'server_address': self.server_address,
|
'server_address': self.server_address,
|
||||||
'auth_key_data':
|
'auth_key_data':
|
||||||
b64encode(self.auth_key.key).decode('ascii')
|
b64encode(self.auth_key.key).decode('ascii')
|
||||||
if self.auth_key else None
|
if self.auth_key else None,
|
||||||
|
'user_id': self.user_id
|
||||||
}
|
}
|
||||||
if self.save_entities:
|
if self.save_entities:
|
||||||
out_dict['entities'] = self.entities.get_input_list()
|
out_dict['entities'] = self.entities.get_input_list()
|
||||||
|
@ -122,6 +125,7 @@ class Session:
|
||||||
result.layer = data.get('layer', result.layer)
|
result.layer = data.get('layer', result.layer)
|
||||||
result.server_address = \
|
result.server_address = \
|
||||||
data.get('server_address', result.server_address)
|
data.get('server_address', result.server_address)
|
||||||
|
result.user_id = data.get('user_id', result.user_id)
|
||||||
|
|
||||||
# FIXME We need to import the AuthKey here or otherwise
|
# FIXME We need to import the AuthKey here or otherwise
|
||||||
# we get cyclic dependencies.
|
# we get cyclic dependencies.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user