Make a clear distinction between connection and codec

This commit is contained in:
Lonami Exo 2019-06-16 15:01:05 +02:00
parent f9ca17c99f
commit f082a27ff8
12 changed files with 257 additions and 219 deletions

View File

@ -10,7 +10,7 @@ from .. import version, helpers, __name__ as __base_name__
from ..crypto import rsa from ..crypto import rsa
from ..entitycache import EntityCache from ..entitycache import EntityCache
from ..extensions import markdown from ..extensions import markdown
from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy from ..network import MTProtoSender, AsyncioConnection, BaseCodec, FullCodec
from ..sessions import Session, SQLiteSession, MemorySession from ..sessions import Session, SQLiteSession, MemorySession
from ..statecache import StateCache from ..statecache import StateCache
from ..tl import TLObject, functions, types from ..tl import TLObject, functions, types
@ -167,7 +167,7 @@ class TelegramBaseClient(abc.ABC):
api_id: int, api_id: int,
api_hash: str, api_hash: str,
*, *,
connection: 'typing.Type[Connection]' = ConnectionTcpFull, connection: 'typing.Type[BaseCodec]' = FullCodec, # TODO rename
use_ipv6: bool = False, use_ipv6: bool = False,
proxy: typing.Union[str, dict] = None, proxy: typing.Union[str, dict] = None,
timeout: int = 10, timeout: int = 10,
@ -257,9 +257,10 @@ class TelegramBaseClient(abc.ABC):
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
assert isinstance(connection, type) assert isinstance(connection, type)
self._connection = connection self._codec = connection
init_proxy = None if not issubclass(connection, TcpMTProxy) else \
types.InputClientProxy(*connection.address_info(self._proxy)) # TODO set types.InputClientProxy if appropriated
init_proxy = None
# Used on connection. Capture the variables in a lambda since # Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest. # exporting clients need to create this InvokeWithLayerRequest.
@ -415,10 +416,11 @@ class TelegramBaseClient(abc.ABC):
except OSError: except OSError:
print('Failed to connect') print('Failed to connect')
""" """
await self._sender.connect(self._connection( await self._sender.connect(AsyncioConnection(
self._session.server_address, self._session.server_address,
self._session.port, self._session.port,
self._session.dc_id, self._session.dc_id,
codec=self._codec(),
loop=self._loop, loop=self._loop,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy

View File

@ -5,10 +5,5 @@ with Telegram's servers and the protocol used (TCP full, abridged, etc.).
from .mtprotoplainsender import MTProtoPlainSender from .mtprotoplainsender import MTProtoPlainSender
from .authenticator import do_authentication from .authenticator import do_authentication
from .mtprotosender import MTProtoSender from .mtprotosender import MTProtoSender
from .connection import ( from .codec import BaseCodec, FullCodec, IntermediateCodec, AbridgedCodec
Connection, from .connection import BaseConnection, AsyncioConnection
ConnectionTcpFull, ConnectionTcpIntermediate, ConnectionTcpAbridged,
ConnectionTcpObfuscated, ConnectionTcpMTProxyAbridged,
ConnectionTcpMTProxyIntermediate,
ConnectionTcpMTProxyRandomizedIntermediate, ConnectionHttp, TcpMTProxy
)

View File

@ -0,0 +1,4 @@
from .basecodec import BaseCodec
from .fullcodec import FullCodec
from .intermediatecodec import IntermediateCodec
from .abridgedcodec import AbridgedCodec

View File

@ -0,0 +1,37 @@
import struct
from .basecodec import BaseCodec
class AbridgedCodec(BaseCodec):
"""
This is the mode with the lowest overhead, as it will
only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common).
"""
@staticmethod
def header_length():
return 1
@staticmethod
def tag():
return b'\xef' # note: obfuscated tag is this 4 times
def encode_packet(self, data):
length = len(data) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
return length + data
def decode_header(self, header):
if len(header) == 4:
length = struct.unpack('<i', header[1:] + b'\0')[0]
else:
length = struct.unpack('<B', header)[0]
if length >= 127:
return -3 # needs 3 more bytes
return length << 2

View File

@ -0,0 +1,53 @@
import abc
class BaseCodec(abc.ABC):
@staticmethod
@abc.abstractmethod
def header_length():
"""
Returns the initial length of the header.
"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def tag():
"""
The bytes tag that identifies the codec.
It may be ``None`` if there is no tag to send.
The tag will be sent upon successful connections to the
server so that it knows which codec we will be using next.
"""
raise NotImplementedError
@abc.abstractmethod
def encode_packet(self, data):
"""
Encodes the given data with the current codec instance.
Should return header + body.
"""
raise NotImplementedError
@abc.abstractmethod
def decode_header(self, header):
"""
Decodes the header.
Should return the length of the body as a positive number.
If more data is needed, a ``-length`` should be returned, where
``length`` is how much more data is needed for the full header.
"""
raise NotImplementedError
def decode_body(self, header, body):
"""
Decodes the body.
The default implementation returns ``body``.
"""
return body

View File

@ -1,43 +1,45 @@
import struct import struct
from zlib import crc32 import zlib
from .connection import Connection, PacketCodec from .basecodec import BaseCodec
from ...errors import InvalidChecksumError from ...errors import InvalidChecksumError
class FullPacketCodec(PacketCodec): class FullCodec(BaseCodec):
tag = None """
Default Telegram codec. Sends 12 additional bytes and
def __init__(self, connection): needs to calculate the CRC value of the packet itself.
super().__init__(connection) """
def __init__(self):
self._send_counter = 0 # Important or Telegram won't reply self._send_counter = 0 # Important or Telegram won't reply
@staticmethod
def header_length():
return 8
@staticmethod
def tag():
return None
def encode_packet(self, data): def encode_packet(self, data):
# https://core.telegram.org/mtproto#tcp-transport # https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32) # total length, sequence number, packet and checksum (CRC32)
length = len(data) + 12 length = len(data) + 12
data = struct.pack('<ii', length, self._send_counter) + data data = struct.pack('<ii', length, self._send_counter) + data
crc = struct.pack('<I', crc32(data)) crc = struct.pack('<I', zlib.crc32(data))
self._send_counter += 1 self._send_counter += 1
return data + crc return data + crc
async def read_packet(self, reader): def decode_header(self, header):
packet_len_seq = await reader.readexactly(8) # 4 and 4 length, seq = struct.unpack('<ii', header)
packet_len, seq = struct.unpack('<ii', packet_len_seq) return length - 8
body = await reader.readexactly(packet_len - 8)
def decode_body(self, header, body):
checksum = struct.unpack('<I', body[-4:])[0] checksum = struct.unpack('<I', body[-4:])[0]
body = body[:-4] body = body[:-4]
valid_checksum = crc32(packet_len_seq + body) valid_checksum = zlib.crc32(header + body)
if checksum != valid_checksum: if checksum != valid_checksum:
raise InvalidChecksumError(checksum, valid_checksum) raise InvalidChecksumError(checksum, valid_checksum)
return body return body
class ConnectionTcpFull(Connection):
"""
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
packet_codec = FullPacketCodec

View File

@ -2,25 +2,33 @@ import struct
import random import random
import os import os
from .connection import Connection, PacketCodec from .basecodec import BaseCodec
class IntermediatePacketCodec(PacketCodec): class IntermediateCodec(BaseCodec):
tag = b'\xee\xee\xee\xee' """
obfuscate_tag = tag Intermediate mode between `FullCodec` and `AbridgedCodec`.
Always sends 4 extra bytes for the packet length.
"""
@staticmethod
def header_length():
return 4
@staticmethod
def tag():
return b'\xee\xee\xee\xee' # same as obfuscate tag
def encode_packet(self, data): def encode_packet(self, data):
return struct.pack('<i', len(data)) + data return struct.pack('<i', len(data)) + data
async def read_packet(self, reader): def decode_header(self, header):
length = struct.unpack('<i', await reader.readexactly(4))[0] return struct.unpack('<i', header)[0]
return await reader.readexactly(length)
class RandomizedIntermediatePacketCodec(IntermediatePacketCodec): class RandomizedIntermediateCodec(IntermediateCodec):
""" """
Data packets are aligned to 4bytes. This codec adds random bytes of size Data packets are aligned to 4 bytes. This codec adds random
from 0 to 3 bytes, which are ignored by decoder. bytes of size from 0 to 3 bytes, which are ignored by decoder.
""" """
tag = None tag = None
obfuscate_tag = b'\xdd\xdd\xdd\xdd' obfuscate_tag = b'\xdd\xdd\xdd\xdd'
@ -31,16 +39,9 @@ class RandomizedIntermediatePacketCodec(IntermediatePacketCodec):
return super().encode_packet(data + padding) return super().encode_packet(data + padding)
async def read_packet(self, reader): async def read_packet(self, reader):
raise NotImplementedError(':)')
packet_with_padding = await super().read_packet(reader) packet_with_padding = await super().read_packet(reader)
pad_size = len(packet_with_padding) % 4 pad_size = len(packet_with_padding) % 4
if pad_size > 0: if pad_size > 0:
return packet_with_padding[:-pad_size] return packet_with_padding[:-pad_size]
return packet_with_padding return packet_with_padding
class ConnectionTcpIntermediate(Connection):
"""
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length.
"""
packet_codec = IntermediatePacketCodec

View File

@ -1,12 +1,2 @@
from .connection import Connection from .baseconnection import BaseConnection
from .tcpfull import ConnectionTcpFull from .asyncioconnection import AsyncioConnection
from .tcpintermediate import ConnectionTcpIntermediate
from .tcpabridged import ConnectionTcpAbridged
from .tcpobfuscated import ConnectionTcpObfuscated
from .tcpmtproxy import (
TcpMTProxy,
ConnectionTcpMTProxyAbridged,
ConnectionTcpMTProxyIntermediate,
ConnectionTcpMTProxyRandomizedIntermediate
)
from .http import ConnectionHttp

View File

@ -6,11 +6,12 @@ import sys
from ...errors import InvalidChecksumError from ...errors import InvalidChecksumError
from ... import helpers from ... import helpers
from .baseconnection import BaseConnection
class Connection(abc.ABC): class AsyncioConnection(BaseConnection):
""" """
The `Connection` class is a wrapper around ``asyncio.open_connection``. The `AsyncioConnection` class is a wrapper around ``asyncio.open_connection``.
Subclasses will implement different transport modes as atomic operations, Subclasses will implement different transport modes as atomic operations,
which this class eases doing since the exposed interface simply puts and which this class eases doing since the exposed interface simply puts and
@ -24,22 +25,15 @@ class Connection(abc.ABC):
# should be one of `PacketCodec` implementations # should be one of `PacketCodec` implementations
packet_codec = None packet_codec = None
def __init__(self, ip, port, dc_id, *, loop, loggers, proxy=None): def __init__(self, ip, port, dc_id, *, loop, codec, loggers, proxy=None):
self._ip = ip super().__init__(ip, port, loop=loop, codec=codec)
self._port = port
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
self._loop = loop
self._log = loggers[__name__] self._log = loggers[__name__]
self._proxy = proxy self._proxy = proxy
self._reader = None self._reader = None
self._writer = None self._writer = None
self._connected = False self._connected = False
self._send_task = None
self._recv_task = None
self._codec = None
self._obfuscation = None # TcpObfuscated and MTProxy self._obfuscation = None # TcpObfuscated and MTProxy
self._send_queue = asyncio.Queue(1)
self._recv_queue = asyncio.Queue(1)
async def _connect(self, timeout=None, ssl=None): async def _connect(self, timeout=None, ssl=None):
if not self._proxy: if not self._proxy:
@ -76,9 +70,14 @@ class Connection(abc.ABC):
connect_coroutine, connect_coroutine,
loop=self._loop, timeout=timeout loop=self._loop, timeout=timeout
) )
self._codec = self.packet_codec(self)
self._init_conn() self._codec.__init__() # reset the codec
await self._writer.drain() if self._codec.tag():
await self._send(self._codec.tag())
@property
def connected(self):
return self._connected
async def connect(self, timeout=None, ssl=None): async def connect(self, timeout=None, ssl=None):
""" """
@ -87,9 +86,6 @@ class Connection(abc.ABC):
await self._connect(timeout=timeout, ssl=ssl) await self._connect(timeout=timeout, ssl=ssl)
self._connected = True self._connected = True
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop())
async def disconnect(self): async def disconnect(self):
""" """
Disconnects from the server, and clears Disconnects from the server, and clears
@ -97,12 +93,6 @@ class Connection(abc.ABC):
""" """
self._connected = False self._connected = False
await helpers._cancel(
self._log,
send_task=self._send_task,
recv_task=self._recv_task
)
if self._writer: if self._writer:
self._writer.close() self._writer.close()
if sys.version_info >= (3, 7): if sys.version_info >= (3, 7):
@ -113,104 +103,16 @@ class Connection(abc.ABC):
# Disconnecting should never raise # Disconnecting should never raise
self._log.warning('Unhandled %s on disconnect: %s', type(e), e) self._log.warning('Unhandled %s on disconnect: %s', type(e), e)
def send(self, data): async def _send(self, data):
""" self._writer.write(data)
Sends a packet of data through this connection mode. await self._writer.drain()
This method returns a coroutine. async def _recv(self, length):
""" return await self._reader.readexactly(length)
if not self._connected:
raise ConnectionError('Not connected')
return self._send_queue.put(data)
async def recv(self): class Connection(abc.ABC):
""" pass
Receives a packet of data through this connection mode.
This method returns a coroutine.
"""
while self._connected:
result = await self._recv_queue.get()
if result: # None = sentinel value = keep trying
return result
raise ConnectionError('Not connected')
async def _send_loop(self):
"""
This loop is constantly popping items off the queue to send them.
"""
try:
while self._connected:
self._send(await self._send_queue.get())
await self._writer.drain()
except asyncio.CancelledError:
pass
except Exception as e:
if isinstance(e, IOError):
self._log.info('The server closed the connection while sending')
else:
self._log.exception('Unexpected exception in the send loop')
await self.disconnect()
async def _recv_loop(self):
"""
This loop is constantly putting items on the queue as they're read.
"""
while self._connected:
try:
data = await self._recv()
except asyncio.CancelledError:
break
except Exception as e:
if isinstance(e, (IOError, asyncio.IncompleteReadError)):
msg = 'The server closed the connection'
self._log.info(msg)
elif isinstance(e, InvalidChecksumError):
msg = 'The server response had an invalid checksum'
self._log.info(msg)
else:
msg = 'Unexpected exception in the receive loop'
self._log.exception(msg)
await self.disconnect()
# Add a sentinel value to unstuck recv
if self._recv_queue.empty():
self._recv_queue.put_nowait(None)
break
try:
await self._recv_queue.put(data)
except asyncio.CancelledError:
break
def _init_conn(self):
"""
This method will be called after `connect` is called.
After this method finishes, the writer will be drained.
Subclasses should make use of this if they need to send
data to Telegram to indicate which connection mode will
be used.
"""
if self._codec.tag:
self._writer.write(self._codec.tag)
def _send(self, data):
self._writer.write(self._codec.encode_packet(data))
async def _recv(self):
return await self._codec.read_packet(self._reader)
def __str__(self):
return '{}:{}/{}'.format(
self._ip, self._port,
self.__class__.__name__.replace('Connection', '')
)
class ObfuscatedConnection(Connection): class ObfuscatedConnection(Connection):

View File

@ -0,0 +1,81 @@
import abc
import asyncio
from ..codec import BaseCodec
class BaseConnection(abc.ABC):
"""
The base connection class.
It offers atomic send and receive methods.
Subclasses are only responsible of sending and receiving data,
since this base class already makes use of the given codec for
correctly adapting the data.
"""
def __init__(self, ip: str, port: int, *, loop: asyncio.AbstractEventLoop, codec: BaseCodec):
self._ip = ip
self._port = port
self._loop = loop
self._codec = codec
self._send_lock = asyncio.Lock(loop=loop)
self._recv_lock = asyncio.Lock(loop=loop)
@property
@abc.abstractmethod
def connected(self):
raise NotImplementedError
@abc.abstractmethod
async def connect(self):
raise NotImplementedError
@abc.abstractmethod
async def disconnect(self):
raise NotImplementedError
@abc.abstractmethod
async def _send(self, data):
raise NotImplementedError
@abc.abstractmethod
async def _recv(self, length):
raise NotImplementedError
async def send(self, data):
if not self.connected:
raise ConnectionError('Not connected')
# TODO Handle asyncio.CancelledError, IOError, Exception
data = self._codec.encode_packet(data)
async with self._send_lock:
return await self._send(data)
async def recv(self):
if not self.connected:
raise ConnectionError('Not connected')
# TODO Handle asyncio.CancelledError, asyncio.IncompleteReadError,
# IOError, InvalidChecksumError, Exception properly
await self._recv_lock.acquire()
try:
header = await self._recv(self._codec.header_length())
length = self._codec.decode_header(header)
while length < 0:
header += await self._recv(-length)
length = self._codec.decode_header(header)
body = await self._recv(length)
return self._codec.decode_body(header, body)
except Exception:
raise ConnectionError
finally:
self._recv_lock.release()
def __str__(self):
return '{}:{}/{}'.format(
self._ip, self._port,
self.__class__.__name__.replace('Connection', '')
)

View File

@ -1,15 +1,16 @@
import asyncio import asyncio
from .connection import Connection, PacketCodec
SSL_PORT = 443 SSL_PORT = 443
class HttpPacketCodec(PacketCodec): class HttpPacketCodec:
tag = None tag = None
obfuscate_tag = None obfuscate_tag = None
def __init__(self):
raise NotImplementedError('Not migrated yet')
def encode_packet(self, data): def encode_packet(self, data):
return ('POST /api HTTP/1.1\r\n' return ('POST /api HTTP/1.1\r\n'
'Host: {}:{}\r\n' 'Host: {}:{}\r\n'
@ -32,8 +33,11 @@ class HttpPacketCodec(PacketCodec):
return await reader.readexactly(length) return await reader.readexactly(length)
class ConnectionHttp(Connection): class ConnectionHttp:
packet_codec = HttpPacketCodec packet_codec = HttpPacketCodec
def __init__(self):
raise NotImplementedError('Not migrated yet')
async def connect(self, timeout=None, ssl=None): async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT) await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)

View File

@ -1,33 +0,0 @@
import struct
from .connection import Connection, PacketCodec
class AbridgedPacketCodec(PacketCodec):
tag = b'\xef'
obfuscate_tag = b'\xef\xef\xef\xef'
def encode_packet(self, data):
length = len(data) >> 2
if length < 127:
length = struct.pack('B', length)
else:
length = b'\x7f' + int.to_bytes(length, 3, 'little')
return length + data
async def read_packet(self, reader):
length = struct.unpack('<B', await reader.readexactly(1))[0]
if length >= 127:
length = struct.unpack(
'<i', await reader.readexactly(3) + b'\0')[0]
return await reader.readexactly(length << 2)
class ConnectionTcpAbridged(Connection):
"""
This is the mode with the lowest overhead, as it will
only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common).
"""
packet_codec = AbridgedPacketCodec