Let all connection modes implement the new Connection

This commit is contained in:
Lonami Exo 2018-09-27 19:22:35 +02:00
parent 096424ea78
commit 2fd51b8582
7 changed files with 95 additions and 202 deletions

View File

@ -1,74 +0,0 @@
"""
This module holds the abstract `Connection` class.
The `Connection.send` and `Connection.recv` methods need **not** to be
safe across several tasks and may use any amount of ``await`` keywords.
The code using these `Connection`'s should be responsible for using
an ``async with asyncio.Lock:`` block when calling said methods.
Said subclasses need not to worry about reconnecting either, and
should let the errors propagate instead.
"""
import abc
class Connection(abc.ABC):
"""
Represents an abstract connection for Telegram.
Subclasses should implement the actual protocol
being used when encoding/decoding messages.
"""
def __init__(self, *, loop, timeout, proxy=None):
"""
Initializes a new connection.
:param loop: the event loop to be used.
:param timeout: timeout to be used for all operations.
:param proxy: whether to use a proxy or not.
"""
self._loop = loop
self._proxy = proxy
self._timeout = timeout
@abc.abstractmethod
async def connect(self, ip, port):
raise NotImplementedError
@abc.abstractmethod
def get_timeout(self):
"""Returns the timeout used by the connection."""
raise NotImplementedError
@abc.abstractmethod
def is_connected(self):
"""
Determines whether the connection is alive or not.
:return: true if it's connected.
"""
raise NotImplementedError
@abc.abstractmethod
async def close(self):
"""Closes the connection."""
raise NotImplementedError
def clone(self):
"""Creates a copy of this Connection."""
return self.__class__(
loop=self._loop,
proxy=self._proxy,
timeout=self._timeout
)
@abc.abstractmethod
async def recv(self):
"""Receives and unpacks a message"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, message):
"""Encapsulates and sends the given message"""
raise NotImplementedError

View File

@ -69,6 +69,7 @@ class Connection(abc.ABC):
self._send(await self._send_queue.get())
await self._writer.drain()
# TODO Handle IncompleteReadError and InvalidChecksumError
async def _recv_loop(self):
"""
This loop is constantly putting items on the queue as they're read.

View File

@ -1,62 +1,38 @@
import errno
import ssl
import asyncio
from .common import Connection
from ...extensions import TcpClient
from .connection import Connection
class ConnectionHttp(Connection):
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy,
ssl=dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
)
self.read = self.conn.read
self.write = self.conn.write
self._host = None
async def connect(self):
# TODO Test if the ssl part works or it needs to be as before:
# dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
self._reader, self._writer = await asyncio.open_connection(
self._ip, self._port, loop=self._loop, ssl=True)
async def connect(self, ip, port):
self._host = '{}:{}'.format(ip, port)
try:
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
self._disconnected.clear()
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._send_loop())
def get_timeout(self):
return self.conn.timeout
def is_connected(self):
return self.conn.is_connected
async def close(self):
self.conn.close()
async def recv(self):
while True:
line = await self._read_line()
if line.lower().startswith(b'content-length: '):
await self.read(2)
length = int(line[16:-2])
return await self.read(length)
async def _read_line(self):
newline = ord('\n')
line = await self.read(1)
while line[-1] != newline:
line += await self.read(1)
return line
async def send(self, message):
await self.write(
def _send(self, message):
self._writer.write(
'POST /api HTTP/1.1\r\n'
'Host: {}\r\n'
'Host: {}:{}\r\n'
'Content-Type: application/x-www-form-urlencoded\r\n'
'Connection: keep-alive\r\n'
'Keep-Alive: timeout=100000, max=10000000\r\n'
'Content-Length: {}\r\n\r\n'.format(self._host, len(message))
'Content-Length: {}\r\n\r\n'
.format(self._ip, self._port, len(message))
.encode('ascii') + message
)
async def _recv(self):
while True:
line = await self._reader.readline()
if not line or line[-1] != b'\n':
raise asyncio.IncompleteReadError(line, None)
if line.lower().startswith(b'content-length: '):
await self._reader.readexactly(2)
length = int(line[16:-2])
return await self._reader.readexactly(length)

View File

@ -1,31 +1,43 @@
import struct
from .tcpfull import ConnectionTcpFull
from .connection import Connection
class ConnectionTcpAbridged(ConnectionTcpFull):
class ConnectionTcpAbridged(Connection):
"""
This is the mode with the lowest overhead, as it will
only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common).
"""
async def connect(self, ip, port):
result = await super().connect(ip, port)
await self.conn.write(b'\xef')
return result
async def connect(self):
await super().connect()
await self.send(b'\xef')
async def recv(self):
length = struct.unpack('<B', await self.read(1))[0]
if length >= 127:
length = struct.unpack('<i', await self.read(3) + b'\0')[0]
def _write(self, data):
"""
Define wrapper write methods for `TcpObfuscated` to override.
"""
self._writer.write(data)
return await self.read(length << 2)
async def _read(self, n):
"""
Define wrapper read methods for `TcpObfuscated` to override.
"""
return await self._reader.readexactly(n)
async def send(self, message):
length = len(message) >> 2
def _send(self, data):
length = len(data) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
await self.write(length + message)
self._write(length + data)
async def _recv(self):
length = struct.unpack('<B', await self._read(1))[0]
if length >= 127:
length = struct.unpack(
'<i', await self._read(3) + b'\0')[0]
return await self._read(length << 2)

View File

@ -1,10 +1,8 @@
import errno
import struct
from zlib import crc32
from .common import Connection
from .connection import Connection
from ...errors import InvalidChecksumError
from ...extensions import TcpClient
class ConnectionTcpFull(Connection):
@ -12,39 +10,23 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
self._send_counter = 0
self.conn = TcpClient(
timeout=self._timeout, loop=self._loop, proxy=self._proxy
)
self.read = self.conn.read
self.write = self.conn.write
async def connect(self, ip, port):
try:
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
else:
raise
def __init__(self, ip, port, *, loop):
super().__init__(ip, port, loop=loop)
self._send_counter = 0
def get_timeout(self):
return self.conn.timeout
def _send(self, data):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(data) + 12
data = struct.pack('<ii', length, self._send_counter) + data
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self._writer.write(data + crc)
def is_connected(self):
return self.conn.is_connected
async def close(self):
self.conn.close()
async def recv(self):
packet_len_seq = await self.read(8) # 4 and 4
async def _recv(self):
packet_len_seq = await self._reader.readexactly(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
body = await self.read(packet_len - 8)
body = await self._reader.readexactly(packet_len - 8)
checksum = struct.unpack('<I', body[-4:])[0]
body = body[:-4]
@ -53,12 +35,3 @@ class ConnectionTcpFull(Connection):
raise InvalidChecksumError(checksum, valid_checksum)
return body
async def send(self, message):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(message) + 12
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
await self.write(data + crc)

View File

@ -1,20 +1,20 @@
import struct
from .tcpfull import ConnectionTcpFull
from .connection import Connection
class ConnectionTcpIntermediate(ConnectionTcpFull):
class ConnectionTcpIntermediate(Connection):
"""
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length.
"""
async def connect(self, ip, port):
result = await super().connect(ip, port)
await self.conn.write(b'\xee\xee\xee\xee')
return result
async def connect(self):
await super().connect()
await self.send(b'\xee\xee\xee\xee')
async def recv(self):
return await self.read(struct.unpack('<i', await self.read(4))[0])
def _send(self, data):
self._writer.write(struct.pack('<i', len(data)) + data)
async def send(self, message):
await self.write(struct.pack('<i', len(message)) + message)
async def _recv(self):
return await self._reader.readexactly(
struct.unpack('<i', await self._reader.readexactly(4))[0])

View File

@ -1,7 +1,7 @@
import os
from .connection import Connection
from .tcpabridged import ConnectionTcpAbridged
from .tcpfull import ConnectionTcpFull
from ...crypto import AESModeCTR
@ -11,16 +11,22 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern.
"""
def __init__(self, *, loop, timeout, proxy=None):
super().__init__(loop=loop, timeout=timeout, proxy=proxy)
self._aes_encrypt, self._aes_decrypt = None, None
self.read = lambda s: self._aes_decrypt.encrypt(self.conn.read(s))
self.write = lambda d: self.conn.write(self._aes_encrypt.encrypt(d))
def __init__(self, ip, port, *, loop):
super().__init__(ip, port, loop=loop)
self._aes_encrypt = None
self._aes_decrypt = None
def _write(self, data):
self._writer.write(self._aes_encrypt.encrypt(data))
async def _read(self, n):
return self._aes_decrypt.encrypt(await self._reader.readexactly(n))
async def connect(self):
await Connection.connect(self)
async def connect(self, ip, port):
result = await ConnectionTcpFull.connect(self, ip, port)
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)
keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee')
while True:
random = os.urandom(64)
if (random[0] != b'\xef' and
@ -28,11 +34,11 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
random[4:4] != b'\0\0\0\0'):
break
random = list(random)
random = bytearray(random)
random[56] = random[57] = random[58] = random[59] = 0xef
random_reversed = random[55:7:-1] # Reversed (8, len=48)
# encryption has "continuous buffer" enabled
# Encryption has "continuous buffer" enabled
encrypt_key = bytes(random[8:40])
encrypt_iv = bytes(random[40:56])
decrypt_key = bytes(random_reversed[:32])
@ -42,5 +48,4 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
self._aes_decrypt = AESModeCTR(decrypt_key, decrypt_iv)
random[56:64] = self._aes_encrypt.encrypt(bytes(random))[56:64]
await self.conn.write(bytes(random))
return result
await self._send(random)