mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-10 19:46:36 +03:00
Use new broken MessagePacker
This commit is contained in:
parent
83f60deef9
commit
e2fe3eb503
|
@ -30,12 +30,18 @@ class AuthKey:
|
|||
self._key = self.aux_hash = self.key_id = None
|
||||
return
|
||||
|
||||
if isinstance(value, type(self)):
|
||||
self._key, self.aux_hash, self.key_id = \
|
||||
value._key, value.aux_hash, value.key_id
|
||||
return
|
||||
|
||||
self._key = value
|
||||
with BinaryReader(sha1(self._key).digest()) as reader:
|
||||
self.aux_hash = reader.read_long(signed=False)
|
||||
reader.read(4)
|
||||
self.key_id = reader.read_long(signed=False)
|
||||
|
||||
# TODO This doesn't really fit here, it's only used in authentication
|
||||
def calc_new_nonce_hash(self, new_nonce, number):
|
||||
"""
|
||||
Calculates the new nonce hash based on the current attributes.
|
||||
|
|
116
telethon/extensions/messagepacker.py
Normal file
116
telethon/extensions/messagepacker.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import io
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from ..tl import TLRequest
|
||||
from ..tl.core.messagecontainer import MessageContainer
|
||||
from ..tl.core.tlmessage import TLMessage
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagePacker:
|
||||
"""
|
||||
This class packs `RequestState` as outgoing `TLMessages`.
|
||||
|
||||
The purpose of this class is to support putting N `RequestState` into a
|
||||
queue, and then awaiting for "packed" `TLMessage` in the other end. The
|
||||
simplest case would be ``State -> TLMessage`` (1-to-1 relationship) but
|
||||
for efficiency purposes it's ``States -> Container`` (N-to-1).
|
||||
|
||||
This addresses several needs: outgoing messages will be smaller, so the
|
||||
encryption and network overhead also is smaller. It's also a central
|
||||
point where outgoing requests are put, and where ready-messages are get.
|
||||
"""
|
||||
def __init__(self, state, loop):
|
||||
self._state = state
|
||||
self._loop = loop
|
||||
self._deque = collections.deque()
|
||||
self._ready = asyncio.Event(loop=loop)
|
||||
|
||||
def append(self, state):
|
||||
self._deque.append(state)
|
||||
self._ready.set()
|
||||
|
||||
def extend(self, states):
|
||||
self._deque.extend(states)
|
||||
self._ready.set()
|
||||
|
||||
async def get(self, cancellation):
|
||||
"""
|
||||
Returns (batch, data) if one or more items could be retrieved.
|
||||
|
||||
If the cancellation occurs or only invalid items were in the
|
||||
queue, (None, None) will be returned instead.
|
||||
"""
|
||||
if not self._deque:
|
||||
self._ready.clear()
|
||||
ready = self._loop.create_task(self._ready.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
[ready, cancellation],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
loop=self._loop
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
done = [cancellation]
|
||||
|
||||
if cancellation in done:
|
||||
ready.cancel()
|
||||
return None, None
|
||||
|
||||
buffer = io.BytesIO()
|
||||
batch = []
|
||||
size = 0
|
||||
|
||||
# Fill a new batch to return while the size is small enough
|
||||
while self._deque:
|
||||
state = self._deque.popleft()
|
||||
size += len(state.data) + TLMessage.SIZE_OVERHEAD
|
||||
|
||||
if size <= MessageContainer.MAXIMUM_SIZE:
|
||||
# TODO Implement back using after_id
|
||||
state.msg_id = self._state.write_data_as_message(
|
||||
buffer, state.data, isinstance(state.request, TLRequest)
|
||||
)
|
||||
batch.append(state)
|
||||
__log__.debug('Assigned msg_id = %d to %s (%x)',
|
||||
state.msg_id, state.request.__class__.__name__,
|
||||
id(state.request))
|
||||
continue
|
||||
|
||||
# Put the item back since it can't be sent in this batch
|
||||
self._deque.appendleft(state)
|
||||
if batch:
|
||||
break
|
||||
|
||||
# If a single message exceeds the maximum size, then the
|
||||
# message payload cannot be sent. Telegram would forcibly
|
||||
# close the connection; message would never be confirmed.
|
||||
state.future.set_exception(
|
||||
ValueError('Request payload is too big'))
|
||||
|
||||
size = 0
|
||||
continue
|
||||
|
||||
if not batch:
|
||||
return None, None
|
||||
|
||||
if len(batch) > 1:
|
||||
# Inlined code to pack several messages into a container
|
||||
data = struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
|
||||
) + buffer.getvalue()
|
||||
buffer = io.BytesIO()
|
||||
container_id = self._state.write_data_as_message(
|
||||
buffer, data, content_related=False
|
||||
)
|
||||
for s in batch:
|
||||
s.container_id = container_id
|
||||
|
||||
data = buffer.getvalue()
|
||||
__log__.debug('Packed %d message(s) in %d bytes for sending',
|
||||
len(batch), len(data))
|
||||
return batch, data
|
|
@ -69,6 +69,7 @@ def get_password_hash(pw, current_salt):
|
|||
|
||||
# region Custom Classes
|
||||
|
||||
|
||||
class TotalList(list):
|
||||
"""
|
||||
A list with an extra `total` property, which may not match its `len`
|
||||
|
@ -88,45 +89,4 @@ class TotalList(list):
|
|||
', '.join(repr(x) for x in self), self.total)
|
||||
|
||||
|
||||
class _ReadyQueue:
|
||||
"""
|
||||
A queue list that supports an arbitrary cancellation token for `get`.
|
||||
"""
|
||||
def __init__(self, loop):
|
||||
self._list = []
|
||||
self._loop = loop
|
||||
self._ready = asyncio.Event(loop=loop)
|
||||
|
||||
def append(self, item):
|
||||
self._list.append(item)
|
||||
self._ready.set()
|
||||
|
||||
def extend(self, items):
|
||||
self._list.extend(items)
|
||||
self._ready.set()
|
||||
|
||||
async def get(self, cancellation):
|
||||
"""
|
||||
Returns a list of all the items added to the queue until now and
|
||||
clears the list from the queue itself. Returns ``None`` if cancelled.
|
||||
"""
|
||||
ready = self._loop.create_task(self._ready.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
[ready, cancellation],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
loop=self._loop
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
done = [cancellation]
|
||||
|
||||
if cancellation in done:
|
||||
ready.cancel()
|
||||
return None
|
||||
|
||||
result = self._list
|
||||
self._list = []
|
||||
self._ready.clear()
|
||||
return result
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,158 +0,0 @@
|
|||
import io
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from .mtprotostate import MTProtoState
|
||||
from ..tl import TLRequest
|
||||
from ..tl.core.tlmessage import TLMessage
|
||||
from ..tl.core.messagecontainer import MessageContainer
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MTProtoLayer:
|
||||
"""
|
||||
This class is the message encryption layer between the methods defined
|
||||
in the schema and the response objects. It also holds the necessary state
|
||||
necessary for this encryption to happen.
|
||||
|
||||
The `connection` parameter is through which these messages will be sent
|
||||
and received.
|
||||
|
||||
The `auth_key` must be a valid authorization key which will be used to
|
||||
encrypt these messages. This class is not responsible for generating them.
|
||||
"""
|
||||
def __init__(self, connection, auth_key):
|
||||
self._connection = connection
|
||||
self._state = MTProtoState(auth_key)
|
||||
|
||||
def connect(self, timeout=None):
|
||||
"""
|
||||
Wrapper for ``self._connection.connect()``.
|
||||
"""
|
||||
return self._connection.connect(timeout=timeout)
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Wrapper for ``self._connection.disconnect()``.
|
||||
"""
|
||||
self._connection.disconnect()
|
||||
|
||||
def reset_state(self):
|
||||
self._state = MTProtoState(self._state.auth_key)
|
||||
|
||||
async def send(self, state_list):
|
||||
"""
|
||||
The list of `RequestState` that will be sent. They will
|
||||
be updated with their new message and container IDs.
|
||||
|
||||
Nested lists imply an order is required for the messages in them.
|
||||
Message containers will be used if there is more than one item.
|
||||
"""
|
||||
for data in filter(None, self._pack_state_list(state_list)):
|
||||
await self._connection.send(self._state.encrypt_message_data(data))
|
||||
|
||||
async def recv(self):
|
||||
"""
|
||||
Reads a single message from the network, decrypts it and returns it.
|
||||
"""
|
||||
body = await self._connection.recv()
|
||||
return self._state.decrypt_message_data(body)
|
||||
|
||||
def _pack_state_list(self, state_list):
|
||||
"""
|
||||
The list of `RequestState` that will be sent. They will
|
||||
be updated with their new message and container IDs.
|
||||
|
||||
Packs all their serialized data into a message (possibly
|
||||
nested inside another message and message container) and
|
||||
returns the serialized message data.
|
||||
"""
|
||||
# Note that the simplest case is writing a single query data into
|
||||
# a message, and returning the message data and ID. For efficiency
|
||||
# purposes this method supports more than one message and automatically
|
||||
# uses containers if deemed necessary.
|
||||
#
|
||||
# Technically the message and message container classes could be used
|
||||
# to store and serialize the data. However, to keep the context local
|
||||
# and relevant to the only place where such feature is actually used,
|
||||
# this is not done.
|
||||
#
|
||||
# When iterating over the state_list there are two branches, one
|
||||
# being just a state and the other being a list so the inner states
|
||||
# depend on each other. In either case, if the packed size exceeds
|
||||
# the maximum container size, it must be sent. This code is non-
|
||||
# trivial so it has been factored into an inner function.
|
||||
#
|
||||
# A new buffer instance will be used once the size should be "flushed"
|
||||
buffer = io.BytesIO()
|
||||
# The batch of requests sent in a single buffer-flush. We need to
|
||||
# remember which states were written to set their container ID.
|
||||
batch = []
|
||||
# The currently written size. Reset when it exceeds the maximum.
|
||||
size = 0
|
||||
|
||||
def write_state(state, after_id=None):
|
||||
nonlocal buffer, batch, size
|
||||
if state:
|
||||
batch.append(state)
|
||||
size += len(state.data) + TLMessage.SIZE_OVERHEAD
|
||||
|
||||
# Flush whenever the current size exceeds the maximum,
|
||||
# or if there's no state, which indicates force flush.
|
||||
if not state or size > MessageContainer.MAXIMUM_SIZE:
|
||||
size -= MessageContainer.MAXIMUM_SIZE
|
||||
if len(batch) > 1:
|
||||
# Inlined code to pack several messages into a container
|
||||
data = struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(batch)
|
||||
) + buffer.getvalue()
|
||||
buffer = io.BytesIO()
|
||||
container_id = self._state.write_data_as_message(
|
||||
buffer, data, content_related=False
|
||||
)
|
||||
for s in batch:
|
||||
s.container_id = container_id
|
||||
|
||||
# At this point it's either a single msg or a msg + container
|
||||
data = buffer.getvalue()
|
||||
__log__.debug('Packed %d message(s) in %d bytes for sending',
|
||||
len(batch), len(data))
|
||||
batch.clear()
|
||||
buffer = io.BytesIO()
|
||||
return data
|
||||
|
||||
if not state:
|
||||
return # Just forcibly flushing
|
||||
|
||||
# If even after flushing it still exceeds the maximum size,
|
||||
# this message payload cannot be sent. Telegram would forcibly
|
||||
# close the connection, and the message would never be confirmed.
|
||||
if size > MessageContainer.MAXIMUM_SIZE:
|
||||
state.future.set_exception(
|
||||
ValueError('Request payload is too big'))
|
||||
return
|
||||
|
||||
# This is the only requirement to make this work.
|
||||
state.msg_id = self._state.write_data_as_message(
|
||||
buffer, state.data, isinstance(state.request, TLRequest),
|
||||
after_id=after_id
|
||||
)
|
||||
__log__.debug('Assigned msg_id = %d to %s (%x)',
|
||||
state.msg_id, state.request.__class__.__name__,
|
||||
id(state.request))
|
||||
|
||||
# TODO Yield in the inner loop -> Telegram "Invalid container". Why?
|
||||
for state in state_list:
|
||||
if not isinstance(state, list):
|
||||
yield write_state(state)
|
||||
else:
|
||||
after_id = None
|
||||
for s in state:
|
||||
yield write_state(s, after_id)
|
||||
after_id = s.msg_id
|
||||
|
||||
yield write_state(None)
|
||||
|
||||
def __str__(self):
|
||||
return str(self._connection)
|
|
@ -3,9 +3,10 @@ import collections
|
|||
import logging
|
||||
|
||||
from . import authenticator
|
||||
from .mtprotolayer import MTProtoLayer
|
||||
from ..extensions.messagepacker import MessagePacker
|
||||
from .mtprotoplainsender import MTProtoPlainSender
|
||||
from .requeststate import RequestState
|
||||
from .mtprotostate import MTProtoState
|
||||
from ..tl.tlobject import TLRequest
|
||||
from .. import utils
|
||||
from ..errors import (
|
||||
|
@ -13,7 +14,6 @@ from ..errors import (
|
|||
InvalidChecksumError, rpc_message_to_error
|
||||
)
|
||||
from ..extensions import BinaryReader
|
||||
from ..helpers import _ReadyQueue
|
||||
from ..tl.core import RpcResult, MessageContainer, GzipPacked
|
||||
from ..tl.functions.auth import LogOutRequest
|
||||
from ..tl.types import (
|
||||
|
@ -21,7 +21,7 @@ from ..tl.types import (
|
|||
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo, MsgsStateReq,
|
||||
MsgsStateInfo, MsgsAllInfo, MsgResendReq, upload
|
||||
)
|
||||
from ..utils import AsyncClassWrapper
|
||||
from ..crypto import AuthKey
|
||||
|
||||
__log__ = logging.getLogger(__name__)
|
||||
|
||||
|
@ -43,9 +43,9 @@ class MTProtoSender:
|
|||
"""
|
||||
def __init__(self, loop, *,
|
||||
retries=5, auto_reconnect=True, connect_timeout=None,
|
||||
update_callback=None,
|
||||
update_callback=None, auth_key=None,
|
||||
auth_key_callback=None, auto_reconnect_callback=None):
|
||||
self._connection = None # MTProtoLayer, a.k.a. encrypted connection
|
||||
self._connection = None
|
||||
self._loop = loop
|
||||
self._retries = retries
|
||||
self._auto_reconnect = auto_reconnect
|
||||
|
@ -68,10 +68,13 @@ class MTProtoSender:
|
|||
self._send_loop_handle = None
|
||||
self._recv_loop_handle = None
|
||||
|
||||
# Preserving the references of the AuthKey and state is important
|
||||
self._auth_key = auth_key or AuthKey(None)
|
||||
self._state = MTProtoState(self._auth_key)
|
||||
|
||||
# Outgoing messages are put in a queue and sent in a batch.
|
||||
# Note that here we're also storing their ``_RequestState``.
|
||||
# Note that it may also store lists (implying order must be kept).
|
||||
self._send_queue = _ReadyQueue(self._loop)
|
||||
self._send_queue = MessagePacker(self._state, self._loop)
|
||||
|
||||
# Sent states are remembered until a response is received.
|
||||
self._pending_state = {}
|
||||
|
@ -112,7 +115,7 @@ class MTProtoSender:
|
|||
__log__.info('User is already connected!')
|
||||
return
|
||||
|
||||
self._connection = MTProtoLayer(connection, auth_key)
|
||||
self._connection = connection
|
||||
self._user_connected = True
|
||||
await self._connect()
|
||||
|
||||
|
@ -204,17 +207,16 @@ class MTProtoSender:
|
|||
.format(self._retries))
|
||||
|
||||
__log__.debug('Connection success!')
|
||||
state = self._connection._state
|
||||
if state.auth_key is None:
|
||||
plain = MTProtoPlainSender(self._connection._connection)
|
||||
if not self._auth_key:
|
||||
plain = MTProtoPlainSender(self._connection)
|
||||
for retry in range(1, self._retries + 1):
|
||||
try:
|
||||
__log__.debug('New auth_key attempt {}...'.format(retry))
|
||||
state.auth_key, state.time_offset =\
|
||||
self._auth_key.key, self._state.time_offset =\
|
||||
await authenticator.do_authentication(plain)
|
||||
|
||||
if self._auth_key_callback:
|
||||
await self._auth_key_callback(state.auth_key)
|
||||
await self._auth_key_callback(self._auth_key)
|
||||
|
||||
break
|
||||
except (SecurityError, AssertionError) as e:
|
||||
|
@ -292,7 +294,7 @@ class MTProtoSender:
|
|||
self._reconnecting = False
|
||||
|
||||
# Start with a clean state (and thus session ID) to avoid old msgs
|
||||
self._connection.reset_state()
|
||||
self._state.reset()
|
||||
|
||||
retries = self._retries if self._auto_reconnect else 0
|
||||
for retry in range(1, retries + 1):
|
||||
|
@ -333,19 +335,19 @@ class MTProtoSender:
|
|||
self._last_acks.append(ack)
|
||||
self._pending_ack.clear()
|
||||
|
||||
state_list = await self._send_queue.get(
|
||||
self._connection._connection.disconnected)
|
||||
batch, data = await self._send_queue.get(
|
||||
self._connection.disconnected)
|
||||
|
||||
if state_list is None:
|
||||
break
|
||||
if not data:
|
||||
continue
|
||||
|
||||
try:
|
||||
await self._connection.send(state_list)
|
||||
await self._connection.send(data)
|
||||
except Exception:
|
||||
__log__.exception('Unhandled error while sending data')
|
||||
continue
|
||||
|
||||
for state in state_list:
|
||||
for state in batch:
|
||||
if not isinstance(state, list):
|
||||
if isinstance(state.request, TLRequest):
|
||||
self._pending_state[state.msg_id] = state
|
||||
|
@ -364,7 +366,9 @@ class MTProtoSender:
|
|||
while self._user_connected and not self._reconnecting:
|
||||
__log__.debug('Receiving items from the network...')
|
||||
try:
|
||||
message = await self._connection.recv()
|
||||
# TODO Split except
|
||||
body = await self._connection.recv()
|
||||
message = self._state.decrypt_message_data(body)
|
||||
except TypeNotFoundError as e:
|
||||
__log__.info('Type %08x not found, remaining data %r',
|
||||
e.invalid_constructor_id, e.remaining)
|
||||
|
@ -388,7 +392,7 @@ class MTProtoSender:
|
|||
else:
|
||||
__log__.warning('Invalid buffer %s', e)
|
||||
|
||||
self._connection._state.auth_key = None
|
||||
self._auth_key.key = None
|
||||
self._start_reconnect()
|
||||
return
|
||||
except asyncio.IncompleteReadError:
|
||||
|
@ -533,7 +537,7 @@ class MTProtoSender:
|
|||
"""
|
||||
bad_salt = message.obj
|
||||
__log__.debug('Handling bad salt for message %d', bad_salt.bad_msg_id)
|
||||
self._connection._state.salt = bad_salt.new_server_salt
|
||||
self._state.salt = bad_salt.new_server_salt
|
||||
states = self._pop_states(bad_salt.bad_msg_id)
|
||||
self._send_queue.extend(states)
|
||||
|
||||
|
@ -554,16 +558,16 @@ class MTProtoSender:
|
|||
if bad_msg.error_code in (16, 17):
|
||||
# Sent msg_id too low or too high (respectively).
|
||||
# Use the current msg_id to determine the right time offset.
|
||||
to = self._connection._state.update_time_offset(
|
||||
to = self._state.update_time_offset(
|
||||
correct_msg_id=message.msg_id)
|
||||
__log__.info('System clock is wrong, set time offset to %ds', to)
|
||||
elif bad_msg.error_code == 32:
|
||||
# msg_seqno too low, so just pump it up by some "large" amount
|
||||
# TODO A better fix would be to start with a new fresh session ID
|
||||
self._connection._state._sequence += 64
|
||||
self._state._sequence += 64
|
||||
elif bad_msg.error_code == 33:
|
||||
# msg_seqno too high never seems to happen but just in case
|
||||
self._connection._state._sequence -= 16
|
||||
self._state._sequence -= 16
|
||||
else:
|
||||
for state in states:
|
||||
state.future.set_exception(BadMessageError(bad_msg.error_code))
|
||||
|
@ -606,7 +610,7 @@ class MTProtoSender:
|
|||
"""
|
||||
# TODO https://goo.gl/LMyN7A
|
||||
__log__.debug('Handling new session created')
|
||||
self._connection._state.salt = message.obj.server_salt
|
||||
self._state.salt = message.obj.server_salt
|
||||
|
||||
async def _handle_ack(self, message):
|
||||
"""
|
||||
|
|
|
@ -38,11 +38,17 @@ class MTProtoState:
|
|||
authentication process, at which point the `MTProtoPlainSender` is better.
|
||||
"""
|
||||
def __init__(self, auth_key):
|
||||
# Session IDs can be random on every connection
|
||||
self.id = struct.unpack('q', os.urandom(8))[0]
|
||||
self.auth_key = auth_key
|
||||
self.time_offset = 0
|
||||
self.salt = 0
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the state.
|
||||
"""
|
||||
# Session IDs can be random on every connection
|
||||
self.id = struct.unpack('q', os.urandom(8))[0]
|
||||
self._sequence = 0
|
||||
self._last_msg_id = 0
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user