Create a new MTProtoSender structure and its foundation

This means that the TcpClient and the Connection (currently only
ConnectionTcpFull) will no longer be concerned about handling
errors, but the MTProtoSender will.

The foundation of the library will now be based on asyncio.
This commit is contained in:
Lonami Exo 2018-06-06 20:41:01 +02:00
parent 4bdc28a775
commit e469258ab9
6 changed files with 284 additions and 143 deletions

View File

@ -1,31 +1,31 @@
"""
This module holds a rough implementation of the C# TCP client.
This class is **not** safe across several tasks since partial reads
may be ``await``'ed before being able to return the exact byte count.
This class is also not concerned about disconnections or retries of
any sort, nor any other kind of errors such as connecting twice.
"""
import errno
import asyncio
import logging
import socket
import time
from datetime import timedelta
from io import BytesIO, BufferedWriter
from threading import Lock
from io import BytesIO
try:
import socks
except ImportError:
socks = None
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
}
__log__ = logging.getLogger(__name__)
# TODO Except asyncio.TimeoutError, ConnectionError, OSError...
# ...for connect, write and read (in the upper levels, not here)
class TcpClient:
"""A simple TCP client to ease the work with sockets and proxies."""
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
def __init__(self, proxy=None, timeout=5):
"""
Initializes the TCP client.
@ -34,31 +34,32 @@ class TcpClient:
"""
self.proxy = proxy
self._socket = None
self._closing_lock = Lock()
self._loop = asyncio.get_event_loop()
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
elif isinstance(timeout, (int, float)):
if isinstance(timeout, (int, float)):
self.timeout = float(timeout)
elif hasattr(timeout, 'seconds'):
self.timeout = float(timeout.seconds)
else:
raise TypeError('Invalid timeout type: {}'.format(type(timeout)))
def _recreate_socket(self, mode):
if self.proxy is None:
self._socket = socket.socket(mode, socket.SOCK_STREAM)
@staticmethod
def _create_socket(mode, proxy):
if proxy is None:
s = socket.socket(mode, socket.SOCK_STREAM)
else:
import socks
self._socket = socks.socksocket(mode, socket.SOCK_STREAM)
if type(self.proxy) is dict:
self._socket.set_proxy(**self.proxy)
s = socks.socksocket(mode, socket.SOCK_STREAM)
if isinstance(proxy, dict):
s.set_proxy(**proxy)
else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy)
s.set_proxy(*proxy)
s.setblocking(False)
return s
self._socket.settimeout(self.timeout)
def connect(self, ip, port):
async def connect(self, ip, port):
"""
Tries connecting forever to IP:port unless an OSError is raised.
Tries connecting to IP:port.
:param ip: the IP to connect to.
:param port: the port to connect to.
@ -69,136 +70,116 @@ class TcpClient:
else:
mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True:
try:
while not self._socket:
self._recreate_socket(mode)
if self._socket is None:
self._socket = self._create_socket(mode, self.proxy)
self._socket.connect(address)
break # Successful connection, stop retrying to connect
except OSError as e:
__log__.info('OSError "%s" raised while connecting', e)
# Stop retrying to connect if proxy connection error occurred
if socks and isinstance(e, socks.ProxyConnectionError):
raise
# There are some errors that we know how to handle, and
# the loop will allow us to retry
if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL,
errno.ECONNREFUSED, # Windows-specific follow
getattr(errno, 'WSAEACCES', None)):
# Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration
self._socket = None
time.sleep(timeout)
timeout *= 2
if timeout > MAX_TIMEOUT:
raise
else:
raise
asyncio.wait_for(self._loop.sock_connect(self._socket, address),
self.timeout, loop=self._loop)
def _get_connected(self):
@property
def is_connected(self):
"""Determines whether the client is connected or not."""
return self._socket is not None and self._socket.fileno() >= 0
connected = property(fget=_get_connected)
def close(self):
"""Closes the connection."""
if self._closing_lock.locked():
# Already closing, no need to close again (avoid None.close())
return
with self._closing_lock:
try:
if self._socket is not None:
try:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally:
self._socket = None
def write(self, data):
async def write(self, data):
"""
Writes (sends) the specified bytes to the connected peer.
:param data: the data to send.
"""
if self._socket is None:
self._raise_connection_reset(None)
if not self.is_connected:
raise ConnectionError()
# TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data.
try:
self._socket.sendall(data)
except socket.timeout as e:
__log__.debug('socket.timeout "%s" while writing data', e)
raise TimeoutError() from e
except ConnectionError as e:
__log__.info('ConnectionError "%s" while writing data', e)
self._raise_connection_reset(e)
except OSError as e:
__log__.info('OSError "%s" while writing data', e)
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset(e)
else:
raise
await asyncio.wait_for(
self.sock_sendall(data),
timeout=self.timeout,
loop=self._loop
)
def read(self, size):
async def read(self, size):
"""
Reads (receives) a whole block of size bytes from the connected peer.
:param size: the size of the block to be read.
:return: the read data with len(data) == size.
"""
if self._socket is None:
self._raise_connection_reset(None)
if not self.is_connected:
raise ConnectionError()
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
with BytesIO() as buffer:
bytes_left = size
while bytes_left != 0:
try:
partial = self._socket.recv(bytes_left)
except socket.timeout as e:
# These are somewhat common if the server has nothing
# to send to us, so use a lower logging priority.
if bytes_left < size:
__log__.warning(
'socket.timeout "%s" when %d/%d had been received',
e, size - bytes_left, size
partial = await asyncio.wait_for(
self.sock_recv(bytes_left),
timeout=self.timeout,
loop=self._loop
)
else:
__log__.debug(
'socket.timeout "%s" while reading data', e
)
raise TimeoutError() from e
except ConnectionError as e:
__log__.info('ConnectionError "%s" while reading data', e)
self._raise_connection_reset(e)
except OSError as e:
if e.errno != errno.EBADF and self._closing_lock.locked():
# Ignore bad file descriptor while closing
__log__.info('OSError "%s" while reading data', e)
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset(e)
else:
raise
if len(partial) == 0:
self._raise_connection_reset(None)
if not partial == 0:
raise ConnectionResetError()
buffer.write(partial)
bytes_left -= len(partial)
# If everything went fine, return the read bytes
buffer.flush()
return buffer.raw.getvalue()
return buffer.getvalue()
def _raise_connection_reset(self, original):
"""Disconnects the client and raises ConnectionResetError."""
self.close() # Connection reset -> flag as socket closed
raise ConnectionResetError('The server has closed the connection.')\
from original
# Due to recent https://github.com/python/cpython/pull/4386
# Credit to @andr-04 for his original implementation
def sock_recv(self, n):
fut = self._loop.create_future()
self._sock_recv(fut, None, n)
return fut
def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None:
self._loop.remove_reader(registered_fd)
if fut.cancelled():
return
try:
data = self._socket.recv(n)
except (BlockingIOError, InterruptedError):
fd = self._socket.fileno()
self._loop.add_reader(fd, self._sock_recv, fut, fd, n)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(data)
def sock_sendall(self, data):
fut = self._loop.create_future()
if data:
self._sock_sendall(fut, None, data)
else:
fut.set_result(None)
return fut
def _sock_sendall(self, fut, registered_fd, data):
if registered_fd:
self._loop.remove_writer(registered_fd)
if fut.cancelled():
return
try:
n = self._socket.send(data)
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
fut.set_exception(exc)
return
if n == len(data):
fut.set_result(None)
else:
if n:
data = data[n:]
fd = self._socket.fileno()
self._loop.add_writer(fd, self._sock_sendall, fut, fd, data)

View File

@ -3,9 +3,9 @@ import os
import struct
from hashlib import sha1, sha256
from telethon.crypto import AES
from telethon.errors import SecurityError
from telethon.extensions import BinaryReader
from .crypto import AES
from .errors import SecurityError, BrokenAuthKeyError
from .extensions import BinaryReader
# region Multiple utilities
@ -46,15 +46,22 @@ def pack_message(session, message):
return key_id + msg_key + AES.encrypt_ige(data + padding, aes_key, aes_iv)
def unpack_message(session, reader):
def unpack_message(session, body):
"""Unpacks a message following MtProto 2.0 guidelines"""
# See https://core.telegram.org/mtproto/description
if reader.read_long(signed=False) != session.auth_key.key_id:
if len(body) < 8:
if body == b'l\xfe\xff\xff':
raise BrokenAuthKeyError()
else:
raise BufferError("Can't decode packet ({})".format(body))
key_id = struct.unpack('<Q', body[:8])[0]
if key_id != session.auth_key.key_id:
raise SecurityError('Server replied with an invalid auth key')
msg_key = reader.read(16)
msg_key = body[8:24]
aes_key, aes_iv = calc_key(session.auth_key.key, msg_key, False)
data = BinaryReader(AES.decrypt_ige(reader.read(), aes_key, aes_iv))
data = BinaryReader(AES.decrypt_ige(body[24:], aes_key, aes_iv))
data.read_long() # remote_salt
if data.read_long() != session.id:

View File

@ -1,5 +1,14 @@
"""
This module holds the abstract `Connection` class.
The `Connection.send` and `Connection.recv` methods need **not** to be
safe across several tasks and may use any amount of ``await`` keywords.
The code using these `Connection`'s should be responsible for using
an ``async with asyncio.Lock:`` block when calling said methods.
Said subclasses need not to worry about reconnecting either, and
should let the errors propagate instead.
"""
import abc
from datetime import timedelta
@ -23,7 +32,7 @@ class Connection(abc.ABC):
self._timeout = timeout
@abc.abstractmethod
def connect(self, ip, port):
async def connect(self, ip, port):
raise NotImplementedError
@abc.abstractmethod
@ -41,7 +50,7 @@ class Connection(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def close(self):
async def close(self):
"""Closes the connection."""
raise NotImplementedError
@ -51,11 +60,11 @@ class Connection(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def recv(self):
async def recv(self):
"""Receives and unpacks a message"""
raise NotImplementedError
@abc.abstractmethod
def send(self, message):
async def send(self, message):
"""Encapsulates and sends the given message"""
raise NotImplementedError

View File

@ -20,7 +20,7 @@ class ConnectionTcpFull(Connection):
self.read = self.conn.read
self.write = self.conn.write
def connect(self, ip, port):
async def connect(self, ip, port):
try:
self.conn.connect(ip, port)
except OSError as e:
@ -37,13 +37,13 @@ class ConnectionTcpFull(Connection):
def is_connected(self):
return self.conn.connected
def close(self):
async def close(self):
self.conn.close()
def clone(self):
return ConnectionTcpFull(self._proxy, self._timeout)
def recv(self):
async def recv(self):
packet_len_seq = self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
body = self.read(packet_len - 12)
@ -55,7 +55,7 @@ class ConnectionTcpFull(Connection):
return body
def send(self, message):
async def send(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12

View File

@ -0,0 +1,144 @@
import asyncio
import logging
from .connection import ConnectionTcpFull
from .. import helpers
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
)
__log__ = logging.getLogger(__name__)
# TODO Create some kind of "ReconnectionPolicy" that allows specifying
# what should be done in case of some errors, with some sane defaults.
# For instance, should all messages be set with an error upon network
# loss? Should we try reconnecting forever? A certain amount of times?
# A timeout? What about recoverable errors, like connection reset?
class MTProtoSender:
def __init__(self, session):
self.session = session
self._connection = ConnectionTcpFull()
self._user_connected = False
# Send and receive calls must be atomic
self._send_lock = asyncio.Lock()
self._recv_lock = asyncio.Lock()
# Sending something shouldn't block
self._send_queue = asyncio.Queue()
# Telegram responds to messages out of order. Keep
# {id: Message} to set their Future result upon arrival.
self._pending_messages = {}
# We need to acknowledge every response from Telegram
self._pending_ack = set()
# Jump table from response ID to method that handles it
self._handlers = {
0xf35c6d01: self._handle_rpc_result,
MessageContainer.CONSTRUCTOR_ID: self._handle_container,
GzipPacked.CONSTRUCTOR_ID: self._handle_gzip_packed,
Pong.CONSTRUCTOR_ID: self._handle_pong,
BadServerSalt.CONSTRUCTOR_ID: self._handle_bad_server_salt,
BadMsgNotification.CONSTRUCTOR_ID: self._handle_bad_notification,
MsgDetailedInfo.CONSTRUCTOR_ID: self._handle_detailed_info,
MsgNewDetailedInfo.CONSTRUCTOR_ID: self._handle_new_detailed_info,
NewSessionCreated.CONSTRUCTOR_ID: self._handle_new_session_created,
MsgsAck.CONSTRUCTOR_ID: self._handle_ack,
FutureSalts.CONSTRUCTOR_ID: self._handle_future_salts
}
# Public API
async def connect(self, ip, port):
self._user_connected = True
async with self._send_lock:
await self._connection.connect(ip, port)
async def disconnect(self):
self._user_connected = False
try:
async with self._send_lock:
await self._connection.close()
except:
__log__.exception('Ignoring exception upon disconnection')
async def send(self, request):
# TODO Should the asyncio.Future creation belong here?
request.result = asyncio.Future()
message = TLMessage(self.session, request)
self._pending_messages[message.msg_id] = message
await self._send_queue.put(message)
# Loops
async def _send_loop(self):
while self._user_connected:
# TODO If there's more than one item, send them all at once
body = helpers.pack_message(
self.session, await self._send_queue.get())
# TODO Handle exceptions
async with self._send_lock:
await self._connection.send(body)
async def _recv_loop(self):
while self._user_connected:
# TODO Handle exceptions
async with self._recv_lock:
body = await self._connection.recv()
# TODO Check salt, session_id and sequence_number
message, remote_msg_id, remote_seq = helpers.unpack_message(
self.session, body)
self._pending_ack.add(remote_msg_id)
with BinaryReader(message) as reader:
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code)
if handler:
handler(remote_msg_id, remote_seq, reader)
else:
pass # TODO Process updates
# Response Handlers
def _handle_rpc_result(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_container(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_gzip_packed(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_pong(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_bad_server_salt(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_bad_notification(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_new_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_new_session_created(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_ack(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_future_salts(self, msg_id, seq, reader):
raise NotImplementedError

View File

@ -6,7 +6,7 @@ from threading import Event
class TLObject:
def __init__(self):
self.rpc_error = None
self.result = None
self.result = None # An asyncio.Future set later
# These should be overrode
self.content_related = False # Only requests/functions/queries are