Use new broken MessagePacker

This commit is contained in:
Lonami Exo 2018-10-19 13:24:52 +02:00
parent 83f60deef9
commit e2fe3eb503
6 changed files with 162 additions and 228 deletions

View File

@ -30,12 +30,18 @@ class AuthKey:
self._key = self.aux_hash = self.key_id = None
return
if isinstance(value, type(self)):
self._key, self.aux_hash, self.key_id = \
value._key, value.aux_hash, value.key_id
return
self._key = value
with BinaryReader(sha1(self._key).digest()) as reader:
self.aux_hash = reader.read_long(signed=False)
reader.read(4)
self.key_id = reader.read_long(signed=False)
# TODO This doesn't really fit here, it's only used in authentication
def calc_new_nonce_hash(self, new_nonce, number):
"""
Calculates the new nonce hash based on the current attributes.

View File

@ -0,0 +1,116 @@
import asyncio
import collections
import io
import logging
import struct
from ..tl import TLRequest
from ..tl.core.messagecontainer import MessageContainer
from ..tl.core.tlmessage import TLMessage
__log__ = logging.getLogger(__name__)
class MessagePacker:
    """
    Packs queued `RequestState` objects into outgoing `TLMessage` data.

    States are put into an internal queue on one end, and `get` hands out
    the serialized bytes for one or more of them on the other end. The
    simplest case is ``State -> TLMessage`` (1-to-1), but for efficiency
    several states are packed together as ``States -> Container`` (N-to-1)
    as long as the container's maximum size is not exceeded.

    Batching keeps outgoing data smaller (so encryption and network
    overhead are smaller too) and provides a single, central place where
    outgoing requests are put and where ready-to-send messages are taken.
    """

    def __init__(self, state, loop):
        """
        `state` must provide ``write_data_as_message``, which is used to
        serialize every payload and obtain its message ID. `loop` is the
        asyncio event loop used while waiting for new items.
        """
        self._state = state
        self._loop = loop
        self._deque = collections.deque()
        self._ready = asyncio.Event(loop=loop)

    def append(self, state):
        """
        Queues a single state and wakes up any pending `get`.
        """
        self._deque.append(state)
        self._ready.set()

    def extend(self, states):
        """
        Queues several states at once and wakes up any pending `get`.
        """
        self._deque.extend(states)
        self._ready.set()

    async def get(self, cancellation):
        """
        Returns ``(batch, data)`` if one or more items could be retrieved.
        `batch` is the list of states that were packed (their ``msg_id``
        and, when a container is used, ``container_id`` are updated) and
        `data` is the serialized message to send.

        If `cancellation` completes first, if this coroutine is cancelled
        while waiting, or if only oversized items were in the queue,
        ``(None, None)`` is returned instead.
        """
        if not self._deque:
            self._ready.clear()
            ready = self._loop.create_task(self._ready.wait())
            try:
                done, _ = await asyncio.wait(
                    [ready, cancellation],
                    return_when=asyncio.FIRST_COMPLETED,
                    loop=self._loop
                )
            except asyncio.CancelledError:
                # Being cancelled while waiting is handled exactly like
                # the cancellation token completing.
                done = [cancellation]

            if cancellation in done:
                ready.cancel()
                return None, None

        buffer = io.BytesIO()
        batch = []
        size = 0

        # Fill a new batch to return while the size is small enough
        while self._deque:
            state = self._deque.popleft()
            size += len(state.data) + TLMessage.SIZE_OVERHEAD

            if size <= MessageContainer.MAXIMUM_SIZE:
                # TODO Implement back using after_id
                state.msg_id = self._state.write_data_as_message(
                    buffer, state.data, isinstance(state.request, TLRequest)
                )
                batch.append(state)
                __log__.debug('Assigned msg_id = %d to %s (%x)',
                              state.msg_id, state.request.__class__.__name__,
                              id(state.request))
                continue

            if batch:
                # Put the item back since it can't be sent in this batch,
                # and send what has been gathered so far.
                self._deque.appendleft(state)
                break

            # If a single message exceeds the maximum size, then the
            # message payload cannot be sent. Telegram would forcibly
            # close the connection; message would never be confirmed.
            #
            # The item is deliberately NOT put back: it can never be sent,
            # and putting it back would make the loop pop it again and set
            # the exception a second time, which raises InvalidStateError.
            state.future.set_exception(
                ValueError('Request payload is too big'))

            size = 0

        if not batch:
            return None, None

        if len(batch) > 1:
            # Inlined code to pack several messages into a container
            data = struct.pack(
                '<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
            ) + buffer.getvalue()
            buffer = io.BytesIO()
            container_id = self._state.write_data_as_message(
                buffer, data, content_related=False
            )
            for s in batch:
                s.container_id = container_id

        # At this point the buffer holds either the single message or
        # the message wrapping the container.
        data = buffer.getvalue()
        __log__.debug('Packed %d message(s) in %d bytes for sending',
                      len(batch), len(data))
        return batch, data

View File

@ -69,6 +69,7 @@ def get_password_hash(pw, current_salt):
# region Custom Classes
class TotalList(list):
"""
A list with an extra `total` property, which may not match its `len`
@ -88,45 +89,4 @@ class TotalList(list):
', '.join(repr(x) for x in self), self.total)
class _ReadyQueue:
    """
    List-backed queue whose `get` can be interrupted through an
    arbitrary cancellation awaitable.
    """

    def __init__(self, loop):
        self._list = []
        self._loop = loop
        self._ready = asyncio.Event(loop=loop)

    def append(self, item):
        """
        Adds one item and flags the queue as ready.
        """
        self._list.append(item)
        self._ready.set()

    def extend(self, items):
        """
        Adds every item from `items` and flags the queue as ready.
        """
        self._list.extend(items)
        self._ready.set()

    async def get(self, cancellation):
        """
        Waits until the queue is flagged as ready and returns the list
        with everything added so far, leaving the queue empty. If the
        cancellation wins the race (or the wait itself is cancelled),
        ``None`` is returned and the queue is left untouched.
        """
        waiter = self._loop.create_task(self._ready.wait())
        try:
            finished, _ = await asyncio.wait(
                [waiter, cancellation],
                loop=self._loop,
                return_when=asyncio.FIRST_COMPLETED
            )
        except asyncio.CancelledError:
            finished = [cancellation]

        if cancellation in finished:
            waiter.cancel()
            return None

        # Hand out the current list and start over with a fresh one
        items, self._list = self._list, []
        self._ready.clear()
        return items
# endregion

View File

@ -1,158 +0,0 @@
import io
import logging
import struct
from .mtprotostate import MTProtoState
from ..tl import TLRequest
from ..tl.core.tlmessage import TLMessage
from ..tl.core.messagecontainer import MessageContainer
__log__ = logging.getLogger(__name__)
class MTProtoLayer:
"""
This class is the message encryption layer between the methods defined
in the schema and the response objects. It also holds the necessary state
necessary for this encryption to happen.
The `connection` parameter is through which these messages will be sent
and received.
The `auth_key` must be a valid authorization key which will be used to
encrypt these messages. This class is not responsible for generating them.
"""
def __init__(self, connection, auth_key):
# The connection is stored as-is; a fresh MTProtoState is built
# around the given auth_key.
self._connection = connection
self._state = MTProtoState(auth_key)
def connect(self, timeout=None):
"""
Wrapper for ``self._connection.connect()``.
"""
return self._connection.connect(timeout=timeout)
def disconnect(self):
"""
Wrapper for ``self._connection.disconnect()``.
"""
self._connection.disconnect()
def reset_state(self):
# Replaces the state with a new MTProtoState that reuses the
# auth_key of the current one.
self._state = MTProtoState(self._state.auth_key)
async def send(self, state_list):
"""
The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Nested lists imply an order is required for the messages in them.
Message containers will be used if there is more than one item.
"""
# _pack_state_list yields None whenever a state did not cause a flush;
# filter(None, ...) drops those so only flushed buffers are encrypted
# by the state and sent through the connection.
for data in filter(None, self._pack_state_list(state_list)):
await self._connection.send(self._state.encrypt_message_data(data))
async def recv(self):
"""
Reads a single message from the network, decrypts it and returns it.
"""
body = await self._connection.recv()
return self._state.decrypt_message_data(body)
def _pack_state_list(self, state_list):
"""
The list of `RequestState` that will be sent. They will
be updated with their new message and container IDs.
Packs all their serialized data into a message (possibly
nested inside another message and message container) and
returns the serialized message data.
"""
# Note that the simplest case is writing a single query data into
# a message, and returning the message data and ID. For efficiency
# purposes this method supports more than one message and automatically
# uses containers if deemed necessary.
#
# Technically the message and message container classes could be used
# to store and serialize the data. However, to keep the context local
# and relevant to the only place where such feature is actually used,
# this is not done.
#
# When iterating over the state_list there are two branches, one
# being just a state and the other being a list so the inner states
# depend on each other. In either case, if the packed size exceeds
# the maximum container size, it must be sent. This code is non-
# trivial so it has been factored into an inner function.
#
# A new buffer instance will be used once the size should be "flushed"
buffer = io.BytesIO()
# The batch of requests sent in a single buffer-flush. We need to
# remember which states were written to set their container ID.
batch = []
# The currently written size. Reset when it exceeds the maximum.
size = 0
# Inner helper: accounts for one state (or force-flushes when called
# with None) and returns the flushed data, or None if nothing was flushed.
def write_state(state, after_id=None):
nonlocal buffer, batch, size
if state:
# NOTE(review): the state joins the batch before the flush check
# below, i.e. before its data has been written to the buffer.
batch.append(state)
size += len(state.data) + TLMessage.SIZE_OVERHEAD
# Flush whenever the current size exceeds the maximum,
# or if there's no state, which indicates force flush.
if not state or size > MessageContainer.MAXIMUM_SIZE:
# NOTE(review): this subtracts the maximum rather than resetting to
# zero, so after a forced flush the counter can become negative --
# confirm this is intended.
size -= MessageContainer.MAXIMUM_SIZE
if len(batch) > 1:
# Inlined code to pack several messages into a container
data = struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
) + buffer.getvalue()
buffer = io.BytesIO()
container_id = self._state.write_data_as_message(
buffer, data, content_related=False
)
for s in batch:
s.container_id = container_id
# At this point it's either a single msg or a msg + container
data = buffer.getvalue()
__log__.debug('Packed %d message(s) in %d bytes for sending',
len(batch), len(data))
batch.clear()
buffer = io.BytesIO()
# NOTE(review): returning here skips the write further below, so a
# state that triggered a size-based flush does not appear to be
# written to any buffer by this call -- verify (possibly related to
# the "Invalid container" TODO near the end of this method).
return data
if not state:
return # Just forcibly flushing
# If even after flushing it still exceeds the maximum size,
# this message payload cannot be sent. Telegram would forcibly
# close the connection, and the message would never be confirmed.
if size > MessageContainer.MAXIMUM_SIZE:
state.future.set_exception(
ValueError('Request payload is too big'))
return
# This is the only requirement to make this work.
state.msg_id = self._state.write_data_as_message(
buffer, state.data, isinstance(state.request, TLRequest),
after_id=after_id
)
__log__.debug('Assigned msg_id = %d to %s (%x)',
state.msg_id, state.request.__class__.__name__,
id(state.request))
# TODO Yield in the inner loop -> Telegram "Invalid container". Why?
for state in state_list:
if not isinstance(state, list):
yield write_state(state)
else:
# A nested list means its states must keep their order: each one
# is written with the msg_id of the previous one as after_id.
after_id = None
for s in state:
yield write_state(s, after_id)
after_id = s.msg_id
# Force-flush whatever is still pending once every state was processed.
yield write_state(None)
def __str__(self):
# The layer is represented by its underlying connection.
return str(self._connection)

View File

@ -3,9 +3,10 @@ import collections
import logging
from . import authenticator
from .mtprotolayer import MTProtoLayer
from ..extensions.messagepacker import MessagePacker
from .mtprotoplainsender import MTProtoPlainSender
from .requeststate import RequestState
from .mtprotostate import MTProtoState
from ..tl.tlobject import TLRequest
from .. import utils
from ..errors import (
@ -13,7 +14,6 @@ from ..errors import (
InvalidChecksumError, rpc_message_to_error
)
from ..extensions import BinaryReader
from ..helpers import _ReadyQueue
from ..tl.core import RpcResult, MessageContainer, GzipPacked
from ..tl.functions.auth import LogOutRequest
from ..tl.types import (
@ -21,7 +21,7 @@ from ..tl.types import (
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
)
from ..utils import AsyncClassWrapper
from ..crypto import AuthKey
__log__ = logging.getLogger(__name__)
@ -43,9 +43,9 @@ class MTProtoSender:
"""
def __init__(self, loop, *,
retries=5, auto_reconnect=True, connect_timeout=None,
update_callback=None,
update_callback=None, auth_key=None,
auth_key_callback=None, auto_reconnect_callback=None):
self._connection = None # MTProtoLayer, a.k.a. encrypted connection
self._connection = None
self._loop = loop
self._retries = retries
self._auto_reconnect = auto_reconnect
@ -68,10 +68,13 @@ class MTProtoSender:
self._send_loop_handle = None
self._recv_loop_handle = None
# Preserving the references of the AuthKey and state is important
self._auth_key = auth_key or AuthKey(None)
self._state = MTProtoState(self._auth_key)
# Outgoing messages are put in a queue and sent in a batch.
# Note that here we're also storing their ``_RequestState``.
# Note that it may also store lists (implying order must be kept).
self._send_queue = _ReadyQueue(self._loop)
self._send_queue = MessagePacker(self._state, self._loop)
# Sent states are remembered until a response is received.
self._pending_state = {}
@ -112,7 +115,7 @@ class MTProtoSender:
__log__.info('User is already connected!')
return
self._connection = MTProtoLayer(connection, auth_key)
self._connection = connection
self._user_connected = True
await self._connect()
@ -204,17 +207,16 @@ class MTProtoSender:
.format(self._retries))
__log__.debug('Connection success!')
state = self._connection._state
if state.auth_key is None:
plain = MTProtoPlainSender(self._connection._connection)
if not self._auth_key:
plain = MTProtoPlainSender(self._connection)
for retry in range(1, self._retries + 1):
try:
__log__.debug('New auth_key attempt {}...'.format(retry))
state.auth_key, state.time_offset =\
self._auth_key.key, self._state.time_offset =\
await authenticator.do_authentication(plain)
if self._auth_key_callback:
await self._auth_key_callback(state.auth_key)
await self._auth_key_callback(self._auth_key)
break
except (SecurityError, AssertionError) as e:
@ -292,7 +294,7 @@ class MTProtoSender:
self._reconnecting = False
# Start with a clean state (and thus session ID) to avoid old msgs
self._connection.reset_state()
self._state.reset()
retries = self._retries if self._auto_reconnect else 0
for retry in range(1, retries + 1):
@ -333,19 +335,19 @@ class MTProtoSender:
self._last_acks.append(ack)
self._pending_ack.clear()
state_list = await self._send_queue.get(
self._connection._connection.disconnected)
batch, data = await self._send_queue.get(
self._connection.disconnected)
if state_list is None:
break
if not data:
continue
try:
await self._connection.send(state_list)
await self._connection.send(data)
except Exception:
__log__.exception('Unhandled error while sending data')
continue
for state in state_list:
for state in batch:
if not isinstance(state, list):
if isinstance(state.request, TLRequest):
self._pending_state[state.msg_id] = state
@ -364,7 +366,9 @@ class MTProtoSender:
while self._user_connected and not self._reconnecting:
__log__.debug('Receiving items from the network...')
try:
message = await self._connection.recv()
# TODO Split except
body = await self._connection.recv()
message = self._state.decrypt_message_data(body)
except TypeNotFoundError as e:
__log__.info('Type %08x not found, remaining data %r',
e.invalid_constructor_id, e.remaining)
@ -388,7 +392,7 @@ class MTProtoSender:
else:
__log__.warning('Invalid buffer %s', e)
self._connection._state.auth_key = None
self._auth_key.key = None
self._start_reconnect()
return
except asyncio.IncompleteReadError:
@ -533,7 +537,7 @@ class MTProtoSender:
"""
bad_salt = message.obj
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
self._connection._state.salt = bad_salt.new_server_salt
self._state.salt = bad_salt.new_server_salt
states = self._pop_states(bad_salt.bad_msg_id)
self._send_queue.extend(states)
@ -554,16 +558,16 @@ class MTProtoSender:
if bad_msg.error_code in (16, 17):
# Sent msg_id too low or too high (respectively).
# Use the current msg_id to determine the right time offset.
to = self._connection._state.update_time_offset(
to = self._state.update_time_offset(
correct_msg_id=message.msg_id)
__log__.info('System clock is wrong, set time offset to %ds', to)
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self._connection._state._sequence += 64
self._state._sequence += 64
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self._connection._state._sequence -= 16
self._state._sequence -= 16
else:
for state in states:
state.future.set_exception(BadMessageError(bad_msg.error_code))
@ -606,7 +610,7 @@ class MTProtoSender:
"""
# TODO https://goo.gl/LMyN7A
__log__.debug('Handling new session created')
self._connection._state.salt = message.obj.server_salt
self._state.salt = message.obj.server_salt
async def _handle_ack(self, message):
"""

View File

@ -38,11 +38,17 @@ class MTProtoState:
authentication process, at which point the `MTProtoPlainSender` is better.
"""
def __init__(self, auth_key):
# Session IDs can be random on every connection
self.id = struct.unpack('q', os.urandom(8))[0]
self.auth_key = auth_key
self.time_offset = 0
self.salt = 0
self.reset()
def reset(self):
"""
Resets the state.
"""
# Session IDs can be random on every connection
self.id = struct.unpack('q', os.urandom(8))[0]
self._sequence = 0
self._last_msg_id = 0