Create a centralized Connection class, replaces TcpTransport (#112)

This commit is contained in:
Lonami Exo 2017-08-28 21:23:31 +02:00
parent bc72e52834
commit fa22a3f848
8 changed files with 143 additions and 138 deletions

View File

@ -1,4 +1,4 @@
from .mtproto_plain_sender import MtProtoPlainSender
from .authenticator import do_authentication
from .mtproto_sender import MtProtoSender
from .tcp_transport import TcpTransport
from .connection import Connection

View File

@ -9,12 +9,12 @@ from ..network import MtProtoPlainSender
from ..extensions import BinaryReader, BinaryWriter
def do_authentication(transport):
def do_authentication(connection):
"""Executes the authentication process with the Telegram servers.
If no error is raised, returns both the authorization key and the
time offset.
"""
sender = MtProtoPlainSender(transport)
sender = MtProtoPlainSender(connection)
sender.connect()
# Step 1 sending: PQ Request

View File

@ -0,0 +1,116 @@
from datetime import timedelta
from zlib import crc32
from ..extensions import BinaryWriter, TcpClient
from ..errors import InvalidChecksumError
class Connection:
def __init__(self, ip, port, mode='tcp_abridged',
proxy=None, timeout=timedelta(seconds=5)):
"""Represents an abstract connection (TCP, TCP abridged...).
'mode' may be any of 'tcp_full', 'tcp_abridged'
"""
self.ip = ip
self.port = port
self._mode = mode
self.timeout = timeout
self._send_counter = 0
# TODO Rename "TcpClient" as some sort of generic socket
self.conn = TcpClient(proxy=proxy)
if mode == 'tcp_full':
setattr(self, 'send', self._send_tcp_full)
setattr(self, 'recv', self._recv_tcp_full)
elif mode == 'tcp_abridged':
setattr(self, 'send', self._send_abridged)
setattr(self, 'recv', self._recv_abridged)
def connect(self):
self._send_counter = 0
self.conn.connect(self.ip, self.port,
timeout=round(self.timeout.seconds))
if self._mode == 'tcp_abridged':
self.conn.write(int.to_bytes(239, 1, 'little'))
def is_connected(self):
return self.conn.connected
def close(self):
self.conn.close()
def cancel_receive(self):
"""Cancels (stops) trying to receive from the
remote peer and raises a ReadCancelledError"""
self.conn.cancel_read()
def get_client_delay(self):
"""Gets the client read delay"""
return self.conn.delay
# region Receive implementations
def recv(self, **kwargs):
"""Receives and unpacks a message"""
# TODO Don't ignore kwargs['timeout']?
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + self._mode)
def _recv_tcp_full(self, **kwargs):
packet_length_bytes = self.conn.read(4, self.timeout)
packet_length = int.from_bytes(packet_length_bytes, 'little')
seq_bytes = self.conn.read(4, self.timeout)
seq = int.from_bytes(seq_bytes, 'little')
body = self.conn.read(packet_length - 12, self.timeout)
checksum = int.from_bytes(self.conn.read(4, self.timeout), 'little')
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
if checksum != valid_checksum:
raise InvalidChecksumError(checksum, valid_checksum)
return body
def _recv_abridged(self, **kwargs):
length = int.from_bytes(self.conn.read(1, self.timeout), 'little')
if length >= 127:
length = int.from_bytes(self.conn.read(3, self.timeout) + b'\0', 'little')
return self.conn.read(length << 2)
# endregion
# region Send implementations
def send(self, message):
"""Encapsulates and sends the given message"""
# Default implementation is just an error
raise ValueError('Invalid connection mode specified: ' + self._mode)
def _send_tcp_full(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
with BinaryWriter() as writer:
writer.write_int(len(message) + 12)
writer.write_int(self._send_counter)
writer.write(message)
writer.write_int(crc32(writer.get_bytes()), signed=False)
self._send_counter += 1
self.conn.write(writer.get_bytes())
def _send_abridged(self, message):
with BinaryWriter() as writer:
length = len(message) >> 2
if length < 127:
writer.write_byte(length)
else:
writer.write_byte(127)
writer.write(int.to_bytes(length, 3, 'little'))
writer.write(message)
self.conn.write(writer.get_bytes())
# endregion

View File

@ -6,17 +6,17 @@ from ..extensions import BinaryReader, BinaryWriter
class MtProtoPlainSender:
"""MTProto Mobile Protocol plain sender (https://core.telegram.org/mtproto/description#unencrypted-messages)"""
def __init__(self, transport):
def __init__(self, connection):
self._sequence = 0
self._time_offset = 0
self._last_msg_id = 0
self._transport = transport
self._connection = connection
def connect(self):
self._transport.connect()
self._connection.connect()
def disconnect(self):
self._transport.close()
self._connection.close()
def send(self, data):
"""Sends a plain packet (auth_key_id = 0) containing the given message body (data)"""
@ -27,11 +27,11 @@ class MtProtoPlainSender:
writer.write(data)
packet = writer.get_bytes()
self._transport.send(packet)
self._connection.send(packet)
def receive(self):
"""Receives a plain packet, returning the body of the response"""
seq, body = self._transport.receive()
body = self._connection.recv()
with BinaryReader(body) as reader:
reader.read_long() # auth_key_id
reader.read_long() # msg_id

View File

@ -16,8 +16,8 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
class MtProtoSender:
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
def __init__(self, transport, session):
self.transport = transport
def __init__(self, connection, session):
self.connection = connection
self.session = session
self._logger = logging.getLogger(__name__)
@ -33,14 +33,14 @@ class MtProtoSender:
def connect(self):
"""Connects to the server"""
self.transport.connect()
self.connection.connect()
def is_connected(self):
return self.transport.is_connected()
return self.connection.is_connected()
def disconnect(self):
"""Disconnects from the server"""
self.transport.close()
self.connection.close()
# region Send and receive
@ -81,7 +81,7 @@ class MtProtoSender:
the received data). This also restores the updates thread.
An optional named parameter 'timeout' can be specified if
one desires to override 'self.transport.timeout'.
one desires to override 'self.connection.timeout'.
If 'request' is None, a single item will be read into
the 'updates' list (which cannot be None).
@ -101,7 +101,7 @@ class MtProtoSender:
while (request and not request.confirm_received) or \
(not request and not updates):
self._logger.debug('Trying to .receive() the request result...')
seq, body = self.transport.receive(**kwargs)
body = self.connection.recv(**kwargs)
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
@ -126,7 +126,7 @@ class MtProtoSender:
def cancel_receive(self):
"""Cancels any pending receive operation
by raising a ReadCancelledError"""
self.transport.cancel_receive()
self.connection.cancel_receive()
# endregion
@ -159,7 +159,7 @@ class MtProtoSender:
self.session.auth_key.key_id, signed=False)
cipher_writer.write(msg_key)
cipher_writer.write(cipher_text)
self.transport.send(cipher_writer.get_bytes())
self.connection.send(cipher_writer.get_bytes())
def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes"""

View File

@ -1,111 +0,0 @@
from zlib import crc32
from datetime import timedelta
from ..errors import InvalidChecksumError
from ..extensions import TcpClient, TcpClientObfuscated
from ..extensions import BinaryWriter
class TcpTransport:
def __init__(self, ip_address, port,
proxy=None, timeout=timedelta(seconds=5)):
self.ip = ip_address
self.port = port
self.tcp_client = TcpClientObfuscated(proxy)
self.timeout = timeout
self.send_counter = 0
def connect(self):
"""Connects to the specified IP address and port"""
self.send_counter = 0
self.tcp_client.connect(self.ip, self.port,
timeout=round(self.timeout.seconds))
def is_connected(self):
return self.tcp_client.connected
# Original reference: https://core.telegram.org/mtproto#tcp-transport
# The packets are encoded as:
# total length, sequence number, packet and checksum (CRC32)
def send(self, packet):
"""Sends the given packet (bytes array) to the connected peer"""
if not self.tcp_client.connected:
raise ConnectionError('Client not connected to server.')
with BinaryWriter() as writer:
if isinstance(self.tcp_client, TcpClient):
# 12 = size_of (integer) * 3
writer.write_int(len(packet) + 12)
writer.write_int(self.send_counter)
writer.write(packet)
crc = crc32(writer.get_bytes())
writer.write_int(crc, signed=False)
self.send_counter += 1
elif isinstance(self.tcp_client, TcpClientObfuscated):
length = len(packet) >> 2
if length < 127:
writer.write_byte(length)
else:
writer.write_byte(127)
writer.write(int.to_bytes(length, 3, 'little'))
writer.write(packet)
else:
raise ValueError('Unknown client')
self.tcp_client.write(writer.get_bytes())
def receive(self, **kwargs):
"""Receives a TCP message (tuple(sequence number, body)) from the
connected peer.
If a named 'timeout' parameter is present, it will override
'self.timeout', and this can be a 'timedelta' or 'None'.
"""
if isinstance(self.tcp_client, TcpClient):
timeout = kwargs.get('timeout', self.timeout)
# First read everything we need
packet_length_bytes = self.tcp_client.read(4, timeout)
packet_length = int.from_bytes(packet_length_bytes, byteorder='little')
seq_bytes = self.tcp_client.read(4, timeout)
seq = int.from_bytes(seq_bytes, byteorder='little')
body = self.tcp_client.read(packet_length - 12, timeout)
checksum = int.from_bytes(
self.tcp_client.read(4, timeout), byteorder='little', signed=False)
# Then perform the checks
rv = packet_length_bytes + seq_bytes + body
valid_checksum = crc32(rv)
if checksum != valid_checksum:
raise InvalidChecksumError(checksum, valid_checksum)
# If we passed the tests, we can then return a valid TCP message
return seq, body
elif isinstance(self.tcp_client, TcpClientObfuscated):
packet_length = int.from_bytes(self.tcp_client.read(1), 'little')
if packet_length < 127:
return 0, self.tcp_client.read(packet_length << 2)
else:
plb = self.tcp_client.read(3)
pl = int.from_bytes(plb + b'\0', 'little') << 2
return 0, self.tcp_client.read(pl)
else:
raise ValueError('Unknown client')
def close(self):
self.tcp_client.close()
def cancel_receive(self):
"""Cancels (stops) trying to receive from the
remote peer and raises a ReadCancelledError"""
self.tcp_client.cancel_read()
def get_client_delay(self):
"""Gets the client read delay"""
return self.tcp_client.delay

View File

@ -10,7 +10,7 @@ from . import helpers as utils
from .errors import (
RPCError, FloodWaitError, FileMigrateError, TypeNotFoundError
)
from .network import authenticator, MtProtoSender, TcpTransport
from .network import authenticator, MtProtoSender, Connection
from .utils import get_appropriated_part_size
from .crypto import rsa, CdnDecrypter
@ -117,19 +117,19 @@ class TelegramBareClient:
# If ping failed, ensure we're disconnected first
self.disconnect()
transport = TcpTransport(self.session.server_address,
self.session.port,
proxy=self.proxy,
timeout=self._timeout)
connection = Connection(
self.session.server_address, self.session.port,
proxy=self.proxy, timeout=self._timeout
)
try:
if not self.session.auth_key:
self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(transport)
authenticator.do_authentication(connection)
self.session.save()
self._sender = MtProtoSender(transport, self.session)
self._sender = MtProtoSender(connection, self.session)
self._sender.connect()
# Now it's time to send an InitConnectionRequest

View File

@ -5,7 +5,7 @@ import unittest
import telethon.network.authenticator as authenticator
from telethon.extensions import TcpClient
from telethon.network import TcpTransport
from telethon.network import Connection
def run_server_echo_thread(port):
@ -38,6 +38,6 @@ class NetworkTests(unittest.TestCase):
@staticmethod
def test_authenticator():
transport = TcpTransport('149.154.167.91', 443)
transport = Connection('149.154.167.91', 443)
authenticator.do_authentication(transport)
transport.close()