mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-03 05:04:33 +03:00
Create a centralized Connection class, replaces TcpTransport (#112)
This commit is contained in:
parent
bc72e52834
commit
fa22a3f848
|
@ -1,4 +1,4 @@
|
||||||
from .mtproto_plain_sender import MtProtoPlainSender
|
from .mtproto_plain_sender import MtProtoPlainSender
|
||||||
from .authenticator import do_authentication
|
from .authenticator import do_authentication
|
||||||
from .mtproto_sender import MtProtoSender
|
from .mtproto_sender import MtProtoSender
|
||||||
from .tcp_transport import TcpTransport
|
from .connection import Connection
|
||||||
|
|
|
@ -9,12 +9,12 @@ from ..network import MtProtoPlainSender
|
||||||
from ..extensions import BinaryReader, BinaryWriter
|
from ..extensions import BinaryReader, BinaryWriter
|
||||||
|
|
||||||
|
|
||||||
def do_authentication(transport):
|
def do_authentication(connection):
|
||||||
"""Executes the authentication process with the Telegram servers.
|
"""Executes the authentication process with the Telegram servers.
|
||||||
If no error is raised, returns both the authorization key and the
|
If no error is raised, returns both the authorization key and the
|
||||||
time offset.
|
time offset.
|
||||||
"""
|
"""
|
||||||
sender = MtProtoPlainSender(transport)
|
sender = MtProtoPlainSender(connection)
|
||||||
sender.connect()
|
sender.connect()
|
||||||
|
|
||||||
# Step 1 sending: PQ Request
|
# Step 1 sending: PQ Request
|
||||||
|
|
116
telethon/network/connection.py
Normal file
116
telethon/network/connection.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
from datetime import timedelta
|
||||||
|
from zlib import crc32
|
||||||
|
|
||||||
|
from ..extensions import BinaryWriter, TcpClient
|
||||||
|
from ..errors import InvalidChecksumError
|
||||||
|
|
||||||
|
|
||||||
|
class Connection:
|
||||||
|
def __init__(self, ip, port, mode='tcp_abridged',
|
||||||
|
proxy=None, timeout=timedelta(seconds=5)):
|
||||||
|
"""Represents an abstract connection (TCP, TCP abridged...).
|
||||||
|
'mode' may be any of 'tcp_full', 'tcp_abridged'
|
||||||
|
"""
|
||||||
|
self.ip = ip
|
||||||
|
self.port = port
|
||||||
|
self._mode = mode
|
||||||
|
self.timeout = timeout
|
||||||
|
self._send_counter = 0
|
||||||
|
|
||||||
|
# TODO Rename "TcpClient" as some sort of generic socket
|
||||||
|
self.conn = TcpClient(proxy=proxy)
|
||||||
|
|
||||||
|
if mode == 'tcp_full':
|
||||||
|
setattr(self, 'send', self._send_tcp_full)
|
||||||
|
setattr(self, 'recv', self._recv_tcp_full)
|
||||||
|
|
||||||
|
elif mode == 'tcp_abridged':
|
||||||
|
setattr(self, 'send', self._send_abridged)
|
||||||
|
setattr(self, 'recv', self._recv_abridged)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self._send_counter = 0
|
||||||
|
self.conn.connect(self.ip, self.port,
|
||||||
|
timeout=round(self.timeout.seconds))
|
||||||
|
|
||||||
|
if self._mode == 'tcp_abridged':
|
||||||
|
self.conn.write(int.to_bytes(239, 1, 'little'))
|
||||||
|
|
||||||
|
def is_connected(self):
|
||||||
|
return self.conn.connected
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
def cancel_receive(self):
|
||||||
|
"""Cancels (stops) trying to receive from the
|
||||||
|
remote peer and raises a ReadCancelledError"""
|
||||||
|
self.conn.cancel_read()
|
||||||
|
|
||||||
|
def get_client_delay(self):
|
||||||
|
"""Gets the client read delay"""
|
||||||
|
return self.conn.delay
|
||||||
|
|
||||||
|
# region Receive implementations
|
||||||
|
|
||||||
|
def recv(self, **kwargs):
|
||||||
|
"""Receives and unpacks a message"""
|
||||||
|
# TODO Don't ignore kwargs['timeout']?
|
||||||
|
# Default implementation is just an error
|
||||||
|
raise ValueError('Invalid connection mode specified: ' + self._mode)
|
||||||
|
|
||||||
|
def _recv_tcp_full(self, **kwargs):
|
||||||
|
packet_length_bytes = self.conn.read(4, self.timeout)
|
||||||
|
packet_length = int.from_bytes(packet_length_bytes, 'little')
|
||||||
|
|
||||||
|
seq_bytes = self.conn.read(4, self.timeout)
|
||||||
|
seq = int.from_bytes(seq_bytes, 'little')
|
||||||
|
|
||||||
|
body = self.conn.read(packet_length - 12, self.timeout)
|
||||||
|
checksum = int.from_bytes(self.conn.read(4, self.timeout), 'little')
|
||||||
|
|
||||||
|
valid_checksum = crc32(packet_length_bytes + seq_bytes + body)
|
||||||
|
if checksum != valid_checksum:
|
||||||
|
raise InvalidChecksumError(checksum, valid_checksum)
|
||||||
|
|
||||||
|
return body
|
||||||
|
|
||||||
|
def _recv_abridged(self, **kwargs):
|
||||||
|
length = int.from_bytes(self.conn.read(1, self.timeout), 'little')
|
||||||
|
if length >= 127:
|
||||||
|
length = int.from_bytes(self.conn.read(3, self.timeout) + b'\0', 'little')
|
||||||
|
|
||||||
|
return self.conn.read(length << 2)
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Send implementations
|
||||||
|
|
||||||
|
def send(self, message):
|
||||||
|
"""Encapsulates and sends the given message"""
|
||||||
|
# Default implementation is just an error
|
||||||
|
raise ValueError('Invalid connection mode specified: ' + self._mode)
|
||||||
|
|
||||||
|
def _send_tcp_full(self, message):
|
||||||
|
# https://core.telegram.org/mtproto#tcp-transport
|
||||||
|
# total length, sequence number, packet and checksum (CRC32)
|
||||||
|
with BinaryWriter() as writer:
|
||||||
|
writer.write_int(len(message) + 12)
|
||||||
|
writer.write_int(self._send_counter)
|
||||||
|
writer.write(message)
|
||||||
|
writer.write_int(crc32(writer.get_bytes()), signed=False)
|
||||||
|
self._send_counter += 1
|
||||||
|
self.conn.write(writer.get_bytes())
|
||||||
|
|
||||||
|
def _send_abridged(self, message):
|
||||||
|
with BinaryWriter() as writer:
|
||||||
|
length = len(message) >> 2
|
||||||
|
if length < 127:
|
||||||
|
writer.write_byte(length)
|
||||||
|
else:
|
||||||
|
writer.write_byte(127)
|
||||||
|
writer.write(int.to_bytes(length, 3, 'little'))
|
||||||
|
writer.write(message)
|
||||||
|
self.conn.write(writer.get_bytes())
|
||||||
|
|
||||||
|
# endregion
|
|
@ -6,17 +6,17 @@ from ..extensions import BinaryReader, BinaryWriter
|
||||||
class MtProtoPlainSender:
|
class MtProtoPlainSender:
|
||||||
"""MTProto Mobile Protocol plain sender (https://core.telegram.org/mtproto/description#unencrypted-messages)"""
|
"""MTProto Mobile Protocol plain sender (https://core.telegram.org/mtproto/description#unencrypted-messages)"""
|
||||||
|
|
||||||
def __init__(self, transport):
|
def __init__(self, connection):
|
||||||
self._sequence = 0
|
self._sequence = 0
|
||||||
self._time_offset = 0
|
self._time_offset = 0
|
||||||
self._last_msg_id = 0
|
self._last_msg_id = 0
|
||||||
self._transport = transport
|
self._connection = connection
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self._transport.connect()
|
self._connection.connect()
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
self._transport.close()
|
self._connection.close()
|
||||||
|
|
||||||
def send(self, data):
|
def send(self, data):
|
||||||
"""Sends a plain packet (auth_key_id = 0) containing the given message body (data)"""
|
"""Sends a plain packet (auth_key_id = 0) containing the given message body (data)"""
|
||||||
|
@ -27,11 +27,11 @@ class MtProtoPlainSender:
|
||||||
writer.write(data)
|
writer.write(data)
|
||||||
|
|
||||||
packet = writer.get_bytes()
|
packet = writer.get_bytes()
|
||||||
self._transport.send(packet)
|
self._connection.send(packet)
|
||||||
|
|
||||||
def receive(self):
|
def receive(self):
|
||||||
"""Receives a plain packet, returning the body of the response"""
|
"""Receives a plain packet, returning the body of the response"""
|
||||||
seq, body = self._transport.receive()
|
body = self._connection.recv()
|
||||||
with BinaryReader(body) as reader:
|
with BinaryReader(body) as reader:
|
||||||
reader.read_long() # auth_key_id
|
reader.read_long() # auth_key_id
|
||||||
reader.read_long() # msg_id
|
reader.read_long() # msg_id
|
||||||
|
|
|
@ -16,8 +16,8 @@ logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
class MtProtoSender:
|
class MtProtoSender:
|
||||||
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
|
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
|
||||||
|
|
||||||
def __init__(self, transport, session):
|
def __init__(self, connection, session):
|
||||||
self.transport = transport
|
self.connection = connection
|
||||||
self.session = session
|
self.session = session
|
||||||
self._logger = logging.getLogger(__name__)
|
self._logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -33,14 +33,14 @@ class MtProtoSender:
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
"""Connects to the server"""
|
"""Connects to the server"""
|
||||||
self.transport.connect()
|
self.connection.connect()
|
||||||
|
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
return self.transport.is_connected()
|
return self.connection.is_connected()
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Disconnects from the server"""
|
"""Disconnects from the server"""
|
||||||
self.transport.close()
|
self.connection.close()
|
||||||
|
|
||||||
# region Send and receive
|
# region Send and receive
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ class MtProtoSender:
|
||||||
the received data). This also restores the updates thread.
|
the received data). This also restores the updates thread.
|
||||||
|
|
||||||
An optional named parameter 'timeout' can be specified if
|
An optional named parameter 'timeout' can be specified if
|
||||||
one desires to override 'self.transport.timeout'.
|
one desires to override 'self.connection.timeout'.
|
||||||
|
|
||||||
If 'request' is None, a single item will be read into
|
If 'request' is None, a single item will be read into
|
||||||
the 'updates' list (which cannot be None).
|
the 'updates' list (which cannot be None).
|
||||||
|
@ -101,7 +101,7 @@ class MtProtoSender:
|
||||||
while (request and not request.confirm_received) or \
|
while (request and not request.confirm_received) or \
|
||||||
(not request and not updates):
|
(not request and not updates):
|
||||||
self._logger.debug('Trying to .receive() the request result...')
|
self._logger.debug('Trying to .receive() the request result...')
|
||||||
seq, body = self.transport.receive(**kwargs)
|
body = self.connection.recv(**kwargs)
|
||||||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||||
|
|
||||||
with BinaryReader(message) as reader:
|
with BinaryReader(message) as reader:
|
||||||
|
@ -126,7 +126,7 @@ class MtProtoSender:
|
||||||
def cancel_receive(self):
|
def cancel_receive(self):
|
||||||
"""Cancels any pending receive operation
|
"""Cancels any pending receive operation
|
||||||
by raising a ReadCancelledError"""
|
by raising a ReadCancelledError"""
|
||||||
self.transport.cancel_receive()
|
self.connection.cancel_receive()
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ class MtProtoSender:
|
||||||
self.session.auth_key.key_id, signed=False)
|
self.session.auth_key.key_id, signed=False)
|
||||||
cipher_writer.write(msg_key)
|
cipher_writer.write(msg_key)
|
||||||
cipher_writer.write(cipher_text)
|
cipher_writer.write(cipher_text)
|
||||||
self.transport.send(cipher_writer.get_bytes())
|
self.connection.send(cipher_writer.get_bytes())
|
||||||
|
|
||||||
def _decode_msg(self, body):
|
def _decode_msg(self, body):
|
||||||
"""Decodes an received encrypted message body bytes"""
|
"""Decodes an received encrypted message body bytes"""
|
||||||
|
|
|
@ -1,111 +0,0 @@
|
||||||
from zlib import crc32
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
from ..errors import InvalidChecksumError
|
|
||||||
from ..extensions import TcpClient, TcpClientObfuscated
|
|
||||||
from ..extensions import BinaryWriter
|
|
||||||
|
|
||||||
|
|
||||||
class TcpTransport:
|
|
||||||
def __init__(self, ip_address, port,
|
|
||||||
proxy=None, timeout=timedelta(seconds=5)):
|
|
||||||
self.ip = ip_address
|
|
||||||
self.port = port
|
|
||||||
self.tcp_client = TcpClientObfuscated(proxy)
|
|
||||||
self.timeout = timeout
|
|
||||||
self.send_counter = 0
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
"""Connects to the specified IP address and port"""
|
|
||||||
self.send_counter = 0
|
|
||||||
self.tcp_client.connect(self.ip, self.port,
|
|
||||||
timeout=round(self.timeout.seconds))
|
|
||||||
|
|
||||||
def is_connected(self):
|
|
||||||
return self.tcp_client.connected
|
|
||||||
|
|
||||||
# Original reference: https://core.telegram.org/mtproto#tcp-transport
|
|
||||||
# The packets are encoded as:
|
|
||||||
# total length, sequence number, packet and checksum (CRC32)
|
|
||||||
def send(self, packet):
|
|
||||||
"""Sends the given packet (bytes array) to the connected peer"""
|
|
||||||
if not self.tcp_client.connected:
|
|
||||||
raise ConnectionError('Client not connected to server.')
|
|
||||||
|
|
||||||
with BinaryWriter() as writer:
|
|
||||||
if isinstance(self.tcp_client, TcpClient):
|
|
||||||
# 12 = size_of (integer) * 3
|
|
||||||
writer.write_int(len(packet) + 12)
|
|
||||||
writer.write_int(self.send_counter)
|
|
||||||
writer.write(packet)
|
|
||||||
|
|
||||||
crc = crc32(writer.get_bytes())
|
|
||||||
writer.write_int(crc, signed=False)
|
|
||||||
|
|
||||||
self.send_counter += 1
|
|
||||||
elif isinstance(self.tcp_client, TcpClientObfuscated):
|
|
||||||
length = len(packet) >> 2
|
|
||||||
if length < 127:
|
|
||||||
writer.write_byte(length)
|
|
||||||
else:
|
|
||||||
writer.write_byte(127)
|
|
||||||
writer.write(int.to_bytes(length, 3, 'little'))
|
|
||||||
writer.write(packet)
|
|
||||||
else:
|
|
||||||
raise ValueError('Unknown client')
|
|
||||||
|
|
||||||
self.tcp_client.write(writer.get_bytes())
|
|
||||||
|
|
||||||
def receive(self, **kwargs):
|
|
||||||
"""Receives a TCP message (tuple(sequence number, body)) from the
|
|
||||||
connected peer.
|
|
||||||
|
|
||||||
If a named 'timeout' parameter is present, it will override
|
|
||||||
'self.timeout', and this can be a 'timedelta' or 'None'.
|
|
||||||
"""
|
|
||||||
if isinstance(self.tcp_client, TcpClient):
|
|
||||||
timeout = kwargs.get('timeout', self.timeout)
|
|
||||||
|
|
||||||
# First read everything we need
|
|
||||||
packet_length_bytes = self.tcp_client.read(4, timeout)
|
|
||||||
packet_length = int.from_bytes(packet_length_bytes, byteorder='little')
|
|
||||||
|
|
||||||
seq_bytes = self.tcp_client.read(4, timeout)
|
|
||||||
seq = int.from_bytes(seq_bytes, byteorder='little')
|
|
||||||
|
|
||||||
body = self.tcp_client.read(packet_length - 12, timeout)
|
|
||||||
|
|
||||||
checksum = int.from_bytes(
|
|
||||||
self.tcp_client.read(4, timeout), byteorder='little', signed=False)
|
|
||||||
|
|
||||||
# Then perform the checks
|
|
||||||
rv = packet_length_bytes + seq_bytes + body
|
|
||||||
valid_checksum = crc32(rv)
|
|
||||||
|
|
||||||
if checksum != valid_checksum:
|
|
||||||
raise InvalidChecksumError(checksum, valid_checksum)
|
|
||||||
|
|
||||||
# If we passed the tests, we can then return a valid TCP message
|
|
||||||
return seq, body
|
|
||||||
elif isinstance(self.tcp_client, TcpClientObfuscated):
|
|
||||||
packet_length = int.from_bytes(self.tcp_client.read(1), 'little')
|
|
||||||
if packet_length < 127:
|
|
||||||
return 0, self.tcp_client.read(packet_length << 2)
|
|
||||||
else:
|
|
||||||
plb = self.tcp_client.read(3)
|
|
||||||
pl = int.from_bytes(plb + b'\0', 'little') << 2
|
|
||||||
return 0, self.tcp_client.read(pl)
|
|
||||||
else:
|
|
||||||
raise ValueError('Unknown client')
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.tcp_client.close()
|
|
||||||
|
|
||||||
def cancel_receive(self):
|
|
||||||
"""Cancels (stops) trying to receive from the
|
|
||||||
remote peer and raises a ReadCancelledError"""
|
|
||||||
self.tcp_client.cancel_read()
|
|
||||||
|
|
||||||
def get_client_delay(self):
|
|
||||||
"""Gets the client read delay"""
|
|
||||||
return self.tcp_client.delay
|
|
|
@ -10,7 +10,7 @@ from . import helpers as utils
|
||||||
from .errors import (
|
from .errors import (
|
||||||
RPCError, FloodWaitError, FileMigrateError, TypeNotFoundError
|
RPCError, FloodWaitError, FileMigrateError, TypeNotFoundError
|
||||||
)
|
)
|
||||||
from .network import authenticator, MtProtoSender, TcpTransport
|
from .network import authenticator, MtProtoSender, Connection
|
||||||
from .utils import get_appropriated_part_size
|
from .utils import get_appropriated_part_size
|
||||||
from .crypto import rsa, CdnDecrypter
|
from .crypto import rsa, CdnDecrypter
|
||||||
|
|
||||||
|
@ -117,19 +117,19 @@ class TelegramBareClient:
|
||||||
# If ping failed, ensure we're disconnected first
|
# If ping failed, ensure we're disconnected first
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
|
||||||
transport = TcpTransport(self.session.server_address,
|
connection = Connection(
|
||||||
self.session.port,
|
self.session.server_address, self.session.port,
|
||||||
proxy=self.proxy,
|
proxy=self.proxy, timeout=self._timeout
|
||||||
timeout=self._timeout)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.session.auth_key:
|
if not self.session.auth_key:
|
||||||
self.session.auth_key, self.session.time_offset = \
|
self.session.auth_key, self.session.time_offset = \
|
||||||
authenticator.do_authentication(transport)
|
authenticator.do_authentication(connection)
|
||||||
|
|
||||||
self.session.save()
|
self.session.save()
|
||||||
|
|
||||||
self._sender = MtProtoSender(transport, self.session)
|
self._sender = MtProtoSender(connection, self.session)
|
||||||
self._sender.connect()
|
self._sender.connect()
|
||||||
|
|
||||||
# Now it's time to send an InitConnectionRequest
|
# Now it's time to send an InitConnectionRequest
|
||||||
|
|
|
@ -5,7 +5,7 @@ import unittest
|
||||||
|
|
||||||
import telethon.network.authenticator as authenticator
|
import telethon.network.authenticator as authenticator
|
||||||
from telethon.extensions import TcpClient
|
from telethon.extensions import TcpClient
|
||||||
from telethon.network import TcpTransport
|
from telethon.network import Connection
|
||||||
|
|
||||||
|
|
||||||
def run_server_echo_thread(port):
|
def run_server_echo_thread(port):
|
||||||
|
@ -38,6 +38,6 @@ class NetworkTests(unittest.TestCase):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_authenticator():
|
def test_authenticator():
|
||||||
transport = TcpTransport('149.154.167.91', 443)
|
transport = Connection('149.154.167.91', 443)
|
||||||
authenticator.do_authentication(transport)
|
authenticator.do_authentication(transport)
|
||||||
transport.close()
|
transport.close()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user