mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-01-24 00:04:14 +03:00
Create a centralized Connection class, replaces TcpTransport (#112)
This commit is contained in:
parent
bc72e52834
commit
fa22a3f848
|
@ -1,4 +1,4 @@
|
|||
from .mtproto_plain_sender import MtProtoPlainSender
|
||||
from .authenticator import do_authentication
|
||||
from .mtproto_sender import MtProtoSender
|
||||
from .tcp_transport import TcpTransport
|
||||
from .connection import Connection
|
||||
|
|
|
@ -9,12 +9,12 @@ from ..network import MtProtoPlainSender
|
|||
from ..extensions import BinaryReader, BinaryWriter
|
||||
|
||||
|
||||
def do_authentication(transport):
|
||||
def do_authentication(connection):
|
||||
"""Executes the authentication process with the Telegram servers.
|
||||
If no error is raised, returns both the authorization key and the
|
||||
time offset.
|
||||
"""
|
||||
sender = MtProtoPlainSender(transport)
|
||||
sender = MtProtoPlainSender(connection)
|
||||
sender.connect()
|
||||
|
||||
# Step 1 sending: PQ Request
|
||||
|
|
116
telethon/network/connection.py
Normal file
116
telethon/network/connection.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
from datetime import timedelta
|
||||
from zlib import crc32
|
||||
|
||||
from ..extensions import BinaryWriter, TcpClient
|
||||
from ..errors import InvalidChecksumError
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, ip, port, mode='tcp_abridged',
|
||||
proxy=None, timeout=timedelta(seconds=5)):
|
||||
"""Represents an abstract connection (TCP, TCP abridged...).
|
||||
'mode' may be any of 'tcp_full', 'tcp_abridged'
|
||||
"""
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self._mode = mode
|
||||
self.timeout = timeout
|
||||
self._send_counter = 0
|
||||
|
||||
# TODO Rename "TcpClient" as some sort of generic socket
|
||||
self.conn = TcpClient(proxy=proxy)
|
||||
|
||||
if mode == 'tcp_full':
|
||||
setattr(self, 'send', self._send_tcp_full)
|
||||
setattr(self, 'recv', self._recv_tcp_full)
|
||||
|
||||
elif mode == 'tcp_abridged':
|
||||
setattr(self, 'send', self._send_abridged)
|
||||
setattr(self, 'recv', self._recv_abridged)
|
||||
|
||||
def connect(self):
|
||||
self._send_counter = 0
|
||||
self.conn.connect(self.ip, self.port,
|
||||
timeout=round(self.timeout.seconds))
|
||||
|
||||
if self._mode == 'tcp_abridged':
|
||||
self.conn.write(int.to_bytes(239, 1, 'little'))
|
||||
|
||||
def is_connected(self):
|
||||
return self.conn.connected
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def cancel_receive(self):
|
||||
"""Cancels (stops) trying to receive from the
|
||||
remote peer and raises a ReadCancelledError"""
|
||||
self.conn.cancel_read()
|
||||
|
||||
def get_client_delay(self):
|
||||
"""Gets the client read delay"""
|
||||
return self.conn.delay
|
||||
|
||||
# region Receive implementations
|
||||
|
||||
def recv(self, **kwargs):
|
||||
"""Receives and unpacks a message"""
|
||||
# TODO Don't ignore kwargs['timeout']?
|
||||
# Default implementation is just an error
|
||||
raise ValueError('Invalid connection mode specified: ' + self._mode)
|
||||
|
||||
def _recv_tcp_full(self, **kwargs):
|
||||
packet_length_bytes = self.conn.read(4, self.timeout)
|
||||
packet_length = int.from_bytes(packet_length_bytes, 'little')
|
||||
|
||||
seq_bytes = self.conn.read(4, self.timeout)
|
||||
seq = int.from_bytes(seq_bytes, 'little')
|
||||
|
||||
body = self.conn.read(packet_length - 12, self.timeout)
|
||||
checksum = int.from_bytes(self.conn.read(4, self.timeout), 'little')
|
||||
|
||||
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
|
||||
if checksum != valid_checksum:
|
||||
raise InvalidChecksumError(checksum, valid_checksum)
|
||||
|
||||
return body
|
||||
|
||||
def _recv_abridged(self, **kwargs):
|
||||
length = int.from_bytes(self.conn.read(1, self.timeout), 'little')
|
||||
if length >= 127:
|
||||
length = int.from_bytes(self.conn.read(3, self.timeout) + b'\0', 'little')
|
||||
|
||||
return self.conn.read(length << 2)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Send implementations
|
||||
|
||||
def send(self, message):
|
||||
"""Encapsulates and sends the given message"""
|
||||
# Default implementation is just an error
|
||||
raise ValueError('Invalid connection mode specified: ' + self._mode)
|
||||
|
||||
def _send_tcp_full(self, message):
|
||||
# https://core.telegram.org/mtproto#tcp-transport
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
with BinaryWriter() as writer:
|
||||
writer.write_int(len(message) + 12)
|
||||
writer.write_int(self._send_counter)
|
||||
writer.write(message)
|
||||
writer.write_int(crc32(writer.get_bytes()), signed=False)
|
||||
self._send_counter += 1
|
||||
self.conn.write(writer.get_bytes())
|
||||
|
||||
def _send_abridged(self, message):
|
||||
with BinaryWriter() as writer:
|
||||
length = len(message) >> 2
|
||||
if length < 127:
|
||||
writer.write_byte(length)
|
||||
else:
|
||||
writer.write_byte(127)
|
||||
writer.write(int.to_bytes(length, 3, 'little'))
|
||||
writer.write(message)
|
||||
self.conn.write(writer.get_bytes())
|
||||
|
||||
# endregion
|
|
@ -6,17 +6,17 @@ from ..extensions import BinaryReader, BinaryWriter
|
|||
class MtProtoPlainSender:
|
||||
"""MTProto Mobile Protocol plain sender (https://core.telegram.org/mtproto/description#unencrypted-messages)"""
|
||||
|
||||
def __init__(self, transport):
|
||||
def __init__(self, connection):
|
||||
self._sequence = 0
|
||||
self._time_offset = 0
|
||||
self._last_msg_id = 0
|
||||
self._transport = transport
|
||||
self._connection = connection
|
||||
|
||||
def connect(self):
|
||||
self._transport.connect()
|
||||
self._connection.connect()
|
||||
|
||||
def disconnect(self):
|
||||
self._transport.close()
|
||||
self._connection.close()
|
||||
|
||||
def send(self, data):
|
||||
"""Sends a plain packet (auth_key_id = 0) containing the given message body (data)"""
|
||||
|
@ -27,11 +27,11 @@ class MtProtoPlainSender:
|
|||
writer.write(data)
|
||||
|
||||
packet = writer.get_bytes()
|
||||
self._transport.send(packet)
|
||||
self._connection.send(packet)
|
||||
|
||||
def receive(self):
|
||||
"""Receives a plain packet, returning the body of the response"""
|
||||
seq, body = self._transport.receive()
|
||||
body = self._connection.recv()
|
||||
with BinaryReader(body) as reader:
|
||||
reader.read_long() # auth_key_id
|
||||
reader.read_long() # msg_id
|
||||
|
|
|
@ -16,8 +16,8 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
|
|||
class MtProtoSender:
|
||||
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
|
||||
|
||||
def __init__(self, transport, session):
|
||||
self.transport = transport
|
||||
def __init__(self, connection, session):
|
||||
self.connection = connection
|
||||
self.session = session
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -33,14 +33,14 @@ class MtProtoSender:
|
|||
|
||||
def connect(self):
|
||||
"""Connects to the server"""
|
||||
self.transport.connect()
|
||||
self.connection.connect()
|
||||
|
||||
def is_connected(self):
|
||||
return self.transport.is_connected()
|
||||
return self.connection.is_connected()
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnects from the server"""
|
||||
self.transport.close()
|
||||
self.connection.close()
|
||||
|
||||
# region Send and receive
|
||||
|
||||
|
@ -81,7 +81,7 @@ class MtProtoSender:
|
|||
the received data). This also restores the updates thread.
|
||||
|
||||
An optional named parameter 'timeout' can be specified if
|
||||
one desires to override 'self.transport.timeout'.
|
||||
one desires to override 'self.connection.timeout'.
|
||||
|
||||
If 'request' is None, a single item will be read into
|
||||
the 'updates' list (which cannot be None).
|
||||
|
@ -101,7 +101,7 @@ class MtProtoSender:
|
|||
while (request and not request.confirm_received) or \
|
||||
(not request and not updates):
|
||||
self._logger.debug('Trying to .receive() the request result...')
|
||||
seq, body = self.transport.receive(**kwargs)
|
||||
body = self.connection.recv(**kwargs)
|
||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
|
||||
with BinaryReader(message) as reader:
|
||||
|
@ -126,7 +126,7 @@ class MtProtoSender:
|
|||
def cancel_receive(self):
|
||||
"""Cancels any pending receive operation
|
||||
by raising a ReadCancelledError"""
|
||||
self.transport.cancel_receive()
|
||||
self.connection.cancel_receive()
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -159,7 +159,7 @@ class MtProtoSender:
|
|||
self.session.auth_key.key_id, signed=False)
|
||||
cipher_writer.write(msg_key)
|
||||
cipher_writer.write(cipher_text)
|
||||
self.transport.send(cipher_writer.get_bytes())
|
||||
self.connection.send(cipher_writer.get_bytes())
|
||||
|
||||
def _decode_msg(self, body):
|
||||
"""Decodes an received encrypted message body bytes"""
|
||||
|
|
|
@ -1,111 +0,0 @@
|
|||
from zlib import crc32
|
||||
from datetime import timedelta
|
||||
|
||||
from ..errors import InvalidChecksumError
|
||||
from ..extensions import TcpClient, TcpClientObfuscated
|
||||
from ..extensions import BinaryWriter
|
||||
|
||||
|
||||
class TcpTransport:
|
||||
def __init__(self, ip_address, port,
|
||||
proxy=None, timeout=timedelta(seconds=5)):
|
||||
self.ip = ip_address
|
||||
self.port = port
|
||||
self.tcp_client = TcpClientObfuscated(proxy)
|
||||
self.timeout = timeout
|
||||
self.send_counter = 0
|
||||
|
||||
def connect(self):
|
||||
"""Connects to the specified IP address and port"""
|
||||
self.send_counter = 0
|
||||
self.tcp_client.connect(self.ip, self.port,
|
||||
timeout=round(self.timeout.seconds))
|
||||
|
||||
def is_connected(self):
|
||||
return self.tcp_client.connected
|
||||
|
||||
# Original reference: https://core.telegram.org/mtproto#tcp-transport
|
||||
# The packets are encoded as:
|
||||
# total length, sequence number, packet and checksum (CRC32)
|
||||
def send(self, packet):
|
||||
"""Sends the given packet (bytes array) to the connected peer"""
|
||||
if not self.tcp_client.connected:
|
||||
raise ConnectionError('Client not connected to server.')
|
||||
|
||||
with BinaryWriter() as writer:
|
||||
if isinstance(self.tcp_client, TcpClient):
|
||||
# 12 = size_of (integer) * 3
|
||||
writer.write_int(len(packet) + 12)
|
||||
writer.write_int(self.send_counter)
|
||||
writer.write(packet)
|
||||
|
||||
crc = crc32(writer.get_bytes())
|
||||
writer.write_int(crc, signed=False)
|
||||
|
||||
self.send_counter += 1
|
||||
elif isinstance(self.tcp_client, TcpClientObfuscated):
|
||||
length = len(packet) >> 2
|
||||
if length < 127:
|
||||
writer.write_byte(length)
|
||||
else:
|
||||
writer.write_byte(127)
|
||||
writer.write(int.to_bytes(length, 3, 'little'))
|
||||
writer.write(packet)
|
||||
else:
|
||||
raise ValueError('Unknown client')
|
||||
|
||||
self.tcp_client.write(writer.get_bytes())
|
||||
|
||||
def receive(self, **kwargs):
|
||||
"""Receives a TCP message (tuple(sequence number, body)) from the
|
||||
connected peer.
|
||||
|
||||
If a named 'timeout' parameter is present, it will override
|
||||
'self.timeout', and this can be a 'timedelta' or 'None'.
|
||||
"""
|
||||
if isinstance(self.tcp_client, TcpClient):
|
||||
timeout = kwargs.get('timeout', self.timeout)
|
||||
|
||||
# First read everything we need
|
||||
packet_length_bytes = self.tcp_client.read(4, timeout)
|
||||
packet_length = int.from_bytes(packet_length_bytes, byteorder='little')
|
||||
|
||||
seq_bytes = self.tcp_client.read(4, timeout)
|
||||
seq = int.from_bytes(seq_bytes, byteorder='little')
|
||||
|
||||
body = self.tcp_client.read(packet_length - 12, timeout)
|
||||
|
||||
checksum = int.from_bytes(
|
||||
self.tcp_client.read(4, timeout), byteorder='little', signed=False)
|
||||
|
||||
# Then perform the checks
|
||||
rv = packet_length_bytes + seq_bytes + body
|
||||
valid_checksum = crc32(rv)
|
||||
|
||||
if checksum != valid_checksum:
|
||||
raise InvalidChecksumError(checksum, valid_checksum)
|
||||
|
||||
# If we passed the tests, we can then return a valid TCP message
|
||||
return seq, body
|
||||
elif isinstance(self.tcp_client, TcpClientObfuscated):
|
||||
packet_length = int.from_bytes(self.tcp_client.read(1), 'little')
|
||||
if packet_length < 127:
|
||||
return 0, self.tcp_client.read(packet_length << 2)
|
||||
else:
|
||||
plb = self.tcp_client.read(3)
|
||||
pl = int.from_bytes(plb + b'\0', 'little') << 2
|
||||
return 0, self.tcp_client.read(pl)
|
||||
else:
|
||||
raise ValueError('Unknown client')
|
||||
|
||||
def close(self):
|
||||
self.tcp_client.close()
|
||||
|
||||
def cancel_receive(self):
|
||||
"""Cancels (stops) trying to receive from the
|
||||
remote peer and raises a ReadCancelledError"""
|
||||
self.tcp_client.cancel_read()
|
||||
|
||||
def get_client_delay(self):
|
||||
"""Gets the client read delay"""
|
||||
return self.tcp_client.delay
|
|
@ -10,7 +10,7 @@ from . import helpers as utils
|
|||
from .errors import (
|
||||
RPCError, FloodWaitError, FileMigrateError, TypeNotFoundError
|
||||
)
|
||||
from .network import authenticator, MtProtoSender, TcpTransport
|
||||
from .network import authenticator, MtProtoSender, Connection
|
||||
from .utils import get_appropriated_part_size
|
||||
from .crypto import rsa, CdnDecrypter
|
||||
|
||||
|
@ -117,19 +117,19 @@ class TelegramBareClient:
|
|||
# If ping failed, ensure we're disconnected first
|
||||
self.disconnect()
|
||||
|
||||
transport = TcpTransport(self.session.server_address,
|
||||
self.session.port,
|
||||
proxy=self.proxy,
|
||||
timeout=self._timeout)
|
||||
connection = Connection(
|
||||
self.session.server_address, self.session.port,
|
||||
proxy=self.proxy, timeout=self._timeout
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.session.auth_key:
|
||||
self.session.auth_key, self.session.time_offset = \
|
||||
authenticator.do_authentication(transport)
|
||||
authenticator.do_authentication(connection)
|
||||
|
||||
self.session.save()
|
||||
|
||||
self._sender = MtProtoSender(transport, self.session)
|
||||
self._sender = MtProtoSender(connection, self.session)
|
||||
self._sender.connect()
|
||||
|
||||
# Now it's time to send an InitConnectionRequest
|
||||
|
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
|
||||
import telethon.network.authenticator as authenticator
|
||||
from telethon.extensions import TcpClient
|
||||
from telethon.network import TcpTransport
|
||||
from telethon.network import Connection
|
||||
|
||||
|
||||
def run_server_echo_thread(port):
|
||||
|
@ -38,6 +38,6 @@ class NetworkTests(unittest.TestCase):
|
|||
|
||||
@staticmethod
|
||||
def test_authenticator():
|
||||
transport = TcpTransport('149.154.167.91', 443)
|
||||
transport = Connection('149.154.167.91', 443)
|
||||
authenticator.do_authentication(transport)
|
||||
transport.close()
|
||||
|
|
Loading…
Reference in New Issue
Block a user