Let all connection modes implement the new Connection

This commit is contained in:
Lonami Exo 2018-09-27 19:22:35 +02:00
parent 096424ea78
commit 2fd51b8582
7 changed files with 95 additions and 202 deletions

View File

@ -1,74 +0,0 @@
"""
This module holds the abstract `Connection` class.
The `Connection.send` and `Connection.recv` methods need **not** to be
safe across several tasks and may use any amount of ``await`` keywords.
The code using these `Connection`'s should be responsible for using
an ``async with asyncio.Lock:`` block when calling said methods.
Said subclasses need not to worry about reconnecting either, and
should let the errors propagate instead.
"""
import abc
class Connection(abc.ABC):
"""
Represents an abstract connection for Telegram.
Subclasses should implement the actual protocol
being used when encoding/decoding messages.
"""
def __init__(self, *, loop, timeout, proxy=None):
"""
Initializes a new connection.
:param loop: the event loop to be used.
:param timeout: timeout to be used for all operations.
:param proxy: whether to use a proxy or not.
"""
self._loop = loop
self._proxy = proxy
self._timeout = timeout
@abc.abstractmethod
async def connect(self, ip, port):
raise NotImplementedError
@abc.abstractmethod
def get_timeout(self):
"""Returns the timeout used by the connection."""
raise NotImplementedError
@abc.abstractmethod
def is_connected(self):
"""
Determines whether the connection is alive or not.
:return: true if it's connected.
"""
raise NotImplementedError
@abc.abstractmethod
async def close(self):
"""Closes the connection."""
raise NotImplementedError
def clone(self):
"""Creates a copy of this Connection."""
return self.__class__(
loop=self._loop,
proxy=self._proxy,
timeout=self._timeout
)
@abc.abstractmethod
async def recv(self):
"""Receives and unpacks a message"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, message):
"""Encapsulates and sends the given message"""
raise NotImplementedError

View File

@ -69,6 +69,7 @@ class Connection(abc.ABC):
self._send(await self._send_queue.get()) self._send(await self._send_queue.get())
await self._writer.drain() await self._writer.drain()
# TODO Handle IncompleteReadError and InvalidChecksumError
async def _recv_loop(self): async def _recv_loop(self):
""" """
This loop is constantly putting items on the queue as they're read. This loop is constantly putting items on the queue as they're read.

View File

@ -1,62 +1,38 @@
import errno import asyncio
import ssl
from .common import Connection from .connection import Connection
from ...extensions import TcpClient
class ConnectionHttp(Connection): class ConnectionHttp(Connection):
def __init__(self, *, loop, timeout, proxy=None): async def connect(self):
super().__init__(loop=loop, timeout=timeout, proxy=proxy) # TODO Test if the ssl part works or it needs to be as before:
self.conn = TcpClient( # dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
timeout=self._timeout, loop=self._loop, proxy=self._proxy, self._reader, self._writer = await asyncio.open_connection(
ssl=dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA') self._ip, self._port, loop=self._loop, ssl=True)
)
self.read = self.conn.read
self.write = self.conn.write
self._host = None
async def connect(self, ip, port): self._disconnected.clear()
self._host = '{}:{}'.format(ip, port) self._send_task = self._loop.create_task(self._send_loop())
try: self._recv_task = self._loop.create_task(self._send_loop())
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
def get_timeout(self): def _send(self, message):
return self.conn.timeout self._writer.write(
def is_connected(self):
return self.conn.is_connected
async def close(self):
self.conn.close()
async def recv(self):
while True:
line = await self._read_line()
if line.lower().startswith(b'content-length: '):
await self.read(2)
length = int(line[16:-2])
return await self.read(length)
async def _read_line(self):
newline = ord('\n')
line = await self.read(1)
while line[-1] != newline:
line += await self.read(1)
return line
async def send(self, message):
await self.write(
'POST /api HTTP/1.1\r\n' 'POST /api HTTP/1.1\r\n'
'Host: {}\r\n' 'Host: {}:{}\r\n'
'Content-Type: application/x-www-form-urlencoded\r\n' 'Content-Type: application/x-www-form-urlencoded\r\n'
'Connection: keep-alive\r\n' 'Connection: keep-alive\r\n'
'Keep-Alive: timeout=100000, max=10000000\r\n' 'Keep-Alive: timeout=100000, max=10000000\r\n'
'Content-Length: {}\r\n\r\n'.format(self._host, len(message)) 'Content-Length: {}\r\n\r\n'
.format(self._ip, self._port, len(message))
.encode('ascii') + message .encode('ascii') + message
) )
async def _recv(self):
while True:
line = await self._reader.readline()
if not line or line[-1] != b'\n':
raise asyncio.IncompleteReadError(line, None)
if line.lower().startswith(b'content-length: '):
await self._reader.readexactly(2)
length = int(line[16:-2])
return await self._reader.readexactly(length)

View File

@ -1,31 +1,43 @@
import struct import struct
from .tcpfull import ConnectionTcpFull from .connection import Connection
class ConnectionTcpAbridged(ConnectionTcpFull): class ConnectionTcpAbridged(Connection):
""" """
This is the mode with the lowest overhead, as it will This is the mode with the lowest overhead, as it will
only require 1 byte if the packet length is less than only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common). 508 bytes (127 << 2, which is very common).
""" """
async def connect(self, ip, port): async def connect(self):
result = await super().connect(ip, port) await super().connect()
await self.conn.write(b'\xef') await self.send(b'\xef')
return result
async def recv(self): def _write(self, data):
length = struct.unpack('<B', await self.read(1))[0] """
if length >= 127: Define wrapper write methods for `TcpObfuscated` to override.
length = struct.unpack('<i', await self.read(3) + b'\0')[0] """
self._writer.write(data)
return await self.read(length << 2) async def _read(self, n):
"""
Define wrapper read methods for `TcpObfuscated` to override.
"""
return await self._reader.readexactly(n)
async def send(self, message): def _send(self, data):
length = len(message) >> 2 length = len(data) >> 2
if length < 127: if length < 127:
length = struct.pack('B', length) length = struct.pack('B', length)
else: else:
length = b'\x7f' + int.to_bytes(length, 3, 'little') length = b'\x7f' + int.to_bytes(length, 3, 'little')
await self.write(length + message) self._write(length + data)
async def _recv(self):
length = struct.unpack('<B', await self._read(1))[0]
if length >= 127:
length = struct.unpack(
'<i', await self._read(3) + b'\0')[0]
return await self._read(length << 2)

View File

@ -1,10 +1,8 @@
import errno
import struct import struct
from zlib import crc32 from zlib import crc32
from .common import Connection from .connection import Connection
from ...errors import InvalidChecksumError from ...errors import InvalidChecksumError
from ...extensions import TcpClient
class ConnectionTcpFull(Connection): class ConnectionTcpFull(Connection):
@ -12,39 +10,23 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself. needs to calculate the CRC value of the packet itself.
""" """
def __init__(self, *, loop, timeout, proxy=None): def __init__(self, ip, port, *, loop):
super().__init__(loop=loop, timeout=timeout, proxy=proxy) super().__init__(ip, port, loop=loop)
self._send_counter = 0
self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy
)
self.read = self.conn.read
self.write = self.conn.write
async def connect(self, ip, port):
try:
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
self._send_counter = 0 self._send_counter = 0
def get_timeout(self): def _send(self, data):
return self.conn.timeout # https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(data) + 12
data = struct.pack('<ii', length, self._send_counter) + data
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self._writer.write(data + crc)
def is_connected(self): async def _recv(self):
return self.conn.is_connected packet_len_seq = await self._reader.readexactly(8) # 4 and 4
async def close(self):
self.conn.close()
async def recv(self):
packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq) packet_len, seq = struct.unpack('<ii', packet_len_seq)
body = await self.read(packet_len - 8) body = await self._reader.readexactly(packet_len - 8)
checksum = struct.unpack('<I', body[-4:])[0] checksum = struct.unpack('<I', body[-4:])[0]
body = body[:-4] body = body[:-4]
@ -53,12 +35,3 @@ class ConnectionTcpFull(Connection):
raise InvalidChecksumError(checksum, valid_checksum) raise InvalidChecksumError(checksum, valid_checksum)
return body return body
async def send(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
await self.write(data + crc)

View File

@ -1,20 +1,20 @@
import struct import struct
from .tcpfull import ConnectionTcpFull from .connection import Connection
class ConnectionTcpIntermediate(ConnectionTcpFull): class ConnectionTcpIntermediate(Connection):
""" """
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`. Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length. Always sends 4 extra bytes for the packet length.
""" """
async def connect(self, ip, port): async def connect(self):
result = await super().connect(ip, port) await super().connect()
await self.conn.write(b'\xee\xee\xee\xee') await self.send(b'\xee\xee\xee\xee')
return result
async def recv(self): def _send(self, data):
return await self.read(struct.unpack('<i', await self.read(4))[0]) self._writer.write(struct.pack('<i', len(data)) + data)
async def send(self, message): async def _recv(self):
await self.write(struct.pack('<i', len(message)) + message) return await self._reader.readexactly(
struct.unpack('<i', await self._reader.readexactly(4))[0])

View File

@ -1,7 +1,7 @@
import os import os
from .connection import Connection
from .tcpabridged import ConnectionTcpAbridged from .tcpabridged import ConnectionTcpAbridged
from .tcpfull import ConnectionTcpFull
from ...crypto import AESModeCTR from ...crypto import AESModeCTR
@ -11,16 +11,22 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern. AES-CTR mode so the packets are harder to discern.
""" """
def __init__(self, *, loop, timeout, proxy=None): def __init__(self, ip, port, *, loop):
super().__init__(loop=loop, timeout=timeout, proxy=proxy) super().__init__(ip, port, loop=loop)
self._aes_encrypt, self._aes_decrypt = None, None self._aes_encrypt = None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s)) self._aes_decrypt = None
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
def _write(self, data):
self._writer.write(self._aes_encrypt.encrypt(data))
async def _read(self, n):
return self._aes_decrypt.encrypt(await self._reader.readexactly(n))
async def connect(self):
await Connection.connect(self)
async def connect(self, ip, port):
result = await ConnectionTcpFull.connect(self, ip, port)
# Obfuscated messages secrets cannot start with any of these # Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4) keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee')
while True: while True:
random = os.urandom(64) random = os.urandom(64)
if (random[0] != b'\xef' and if (random[0] != b'\xef' and
@ -28,11 +34,11 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
random[4:4] != b'\0\0\0\0'): random[4:4] != b'\0\0\0\0'):
break break
random = list(random) random = bytearray(random)
random[56] = random[57] = random[58] = random[59] = 0xef random[56] = random[57] = random[58] = random[59] = 0xef
random_reversed = random[55:7:-1] # Reversed (8, len=48) random_reversed = random[55:7:-1] # Reversed (8, len=48)
# encryption has "continuous buffer" enabled # Encryption has "continuous buffer" enabled
encrypt_key = bytes(random[8:40]) encrypt_key = bytes(random[8:40])
encrypt_iv = bytes(random[40:56]) encrypt_iv = bytes(random[40:56])
decrypt_key = bytes(random_reversed[:32]) decrypt_key = bytes(random_reversed[:32])
@ -42,5 +48,4 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv) self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64] random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
await self.conn.write(bytes(random)) await self._send(random)
return result