Make a clear distinction between connection and codec

This commit is contained in:
Lonami Exo 2019-06-16 15:01:05 +02:00
parent f9ca17c99f
commit f082a27ff8
12 changed files with 257 additions and 219 deletions

View File

@ -10,7 +10,7 @@ from .. import version, helpers, __name__ as __base_name__
from ..crypto import rsa
from ..entitycache import EntityCache
from ..extensions import markdown
from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
from ..network import MTProtoSender, AsyncioConnection, BaseCodec, FullCodec
from ..sessions import Session, SQLiteSession, MemorySession
from ..statecache import StateCache
from ..tl import TLObject, functions, types
@ -167,7 +167,7 @@ class TelegramBaseClient(abc.ABC):
api_id: int,
api_hash: str,
*,
connection: 'typing.Type[Connection]' = ConnectionTcpFull,
connection: 'typing.Type[BaseCodec]' = FullCodec, # TODO rename
use_ipv6: bool = False,
proxy: typing.Union[str, dict] = None,
timeout: int = 10,
@ -257,9 +257,10 @@ class TelegramBaseClient(abc.ABC):
self._auto_reconnect = auto_reconnect
assert isinstance(connection, type)
self._connection = connection
init_proxy = None if not issubclass(connection, TcpMTProxy) else \
types.InputClientProxy(*connection.address_info(self._proxy))
self._codec = connection
# TODO set types.InputClientProxy if appropriate
init_proxy = None
# Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayerRequest.
@ -415,10 +416,11 @@ class TelegramBaseClient(abc.ABC):
except OSError:
print('Failed to connect')
"""
await self._sender.connect(self._connection(
await self._sender.connect(AsyncioConnection(
self._session.server_address,
self._session.port,
self._session.dc_id,
codec=self._codec(),
loop=self._loop,
loggers=self._log,
proxy=self._proxy

View File

@ -5,10 +5,5 @@ with Telegram's servers and the protocol used (TCP full, abridged, etc.).
from .mtprotoplainsender import MTProtoPlainSender
from .authenticator import do_authentication
from .mtprotosender import MTProtoSender
from .connection import (
Connection,
ConnectionTcpFull, ConnectionTcpIntermediate, ConnectionTcpAbridged,
ConnectionTcpObfuscated, ConnectionTcpMTProxyAbridged,
ConnectionTcpMTProxyIntermediate,
ConnectionTcpMTProxyRandomizedIntermediate, ConnectionHttp, TcpMTProxy
)
from .codec import BaseCodec, FullCodec, IntermediateCodec, AbridgedCodec
from .connection import BaseConnection, AsyncioConnection

View File

@ -0,0 +1,4 @@
from .basecodec import BaseCodec
from .fullcodec import FullCodec
from .intermediatecodec import IntermediateCodec
from .abridgedcodec import AbridgedCodec

View File

@ -0,0 +1,37 @@
import struct
from .basecodec import BaseCodec
class AbridgedCodec(BaseCodec):
    """
    Codec with the smallest framing overhead.

    The length prefix is expressed in 4-byte words. A single byte is
    enough whenever the packet is shorter than 508 bytes (127 << 2),
    which is the common case; longer packets use the marker byte
    ``0x7f`` followed by a 3-byte little-endian word count.
    """

    @staticmethod
    def header_length():
        # Only the first byte is needed to know if more header follows.
        return 1

    @staticmethod
    def tag():
        return b'\xef'  # note: obfuscated tag is this 4 times

    def encode_packet(self, data):
        """
        Prefixes ``data`` with its length (in 4-byte words) and
        returns header + body.
        """
        words = len(data) >> 2
        if words >= 127:
            prefix = b'\x7f' + words.to_bytes(3, 'little')
        else:
            prefix = struct.pack('B', words)

        return prefix + data

    def decode_header(self, header):
        """
        Returns the body length in bytes, or ``-3`` when the single
        byte read so far is the long-form marker and three more header
        bytes are required.
        """
        if len(header) != 4:
            (words,) = struct.unpack('<B', header)
            if words >= 127:
                return -3  # needs 3 more bytes
        else:
            # Skip the marker byte; pad the 3 length bytes up to an int.
            (words,) = struct.unpack('<i', header[1:] + b'\0')

        return words << 2

View File

@ -0,0 +1,53 @@
import abc
class BaseCodec(abc.ABC):
    """
    Interface every packet codec has to implement.

    A codec only knows how to frame outgoing data and how to
    interpret the framing of incoming data.
    """

    @staticmethod
    @abc.abstractmethod
    def header_length():
        """
        Initial amount of bytes that make up the header.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def tag():
        """
        Bytes identifying this codec, or ``None`` when the codec
        has no tag to send.

        After a successful connection the tag is sent to the server
        so that it knows which codec will be used from then on.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def encode_packet(self, data):
        """
        Frames ``data`` with this codec instance.

        Implementations should return header + body.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def decode_header(self, header):
        """
        Interprets ``header``.

        Implementations should return the body length as a positive
        number. When the header is still incomplete, ``-length`` should
        be returned instead, ``length`` being how many more bytes are
        needed to complete the header.
        """
        raise NotImplementedError

    def decode_body(self, header, body):
        """
        Interprets ``body``.

        Unless overridden, ``body`` is returned untouched.
        """
        return body

View File

@ -1,43 +1,45 @@
import struct
from zlib import crc32
import zlib
from .connection import Connection, PacketCodec
from .basecodec import BaseCodec
from ...errors import InvalidChecksumError
class FullPacketCodec(PacketCodec):
tag = None
def __init__(self, connection):
super().__init__(connection)
class FullCodec(BaseCodec):
"""
Default Telegram codec. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
def __init__(self):
self._send_counter = 0 # Important or Telegram won't reply
@staticmethod
def header_length():
return 8
@staticmethod
def tag():
return None
def encode_packet(self, data):
# https://core.telegram.org/mtproto#tcp-transport
# total length, sequence number, packet and checksum (CRC32)
length = len(data) + 12
data = struct.pack('<ii', length, self._send_counter) + data
crc = struct.pack('<I', crc32(data))
crc = struct.pack('<I', zlib.crc32(data))
self._send_counter += 1
return data + crc
async def read_packet(self, reader):
packet_len_seq = await reader.readexactly(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
body = await reader.readexactly(packet_len - 8)
def decode_header(self, header):
length, seq = struct.unpack('<ii', header)
return length - 8
def decode_body(self, header, body):
checksum = struct.unpack('<I', body[-4:])[0]
body = body[:-4]
valid_checksum = crc32(packet_len_seq + body)
valid_checksum = zlib.crc32(header + body)
if checksum != valid_checksum:
raise InvalidChecksumError(checksum, valid_checksum)
return body
class ConnectionTcpFull(Connection):
"""
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
packet_codec = FullPacketCodec

View File

@ -2,25 +2,33 @@ import struct
import random
import os
from .connection import Connection, PacketCodec
from .basecodec import BaseCodec
class IntermediatePacketCodec(PacketCodec):
tag = b'\xee\xee\xee\xee'
obfuscate_tag = tag
class IntermediateCodec(BaseCodec):
"""
Intermediate mode between `FullCodec` and `AbridgedCodec`.
Always sends 4 extra bytes for the packet length.
"""
@staticmethod
def header_length():
return 4
@staticmethod
def tag():
return b'\xee\xee\xee\xee' # same as obfuscate tag
def encode_packet(self, data):
return struct.pack('<i', len(data)) + data
async def read_packet(self, reader):
length = struct.unpack('<i', await reader.readexactly(4))[0]
return await reader.readexactly(length)
def decode_header(self, header):
return struct.unpack('<i', header)[0]
class RandomizedIntermediatePacketCodec(IntermediatePacketCodec):
class RandomizedIntermediateCodec(IntermediateCodec):
"""
Data packets are aligned to 4bytes. This codec adds random bytes of size
from 0 to 3 bytes, which are ignored by decoder.
Data packets are aligned to 4 bytes. This codec adds random
bytes of size from 0 to 3 bytes, which are ignored by decoder.
"""
tag = None
obfuscate_tag = b'\xdd\xdd\xdd\xdd'
@ -31,16 +39,9 @@ class RandomizedIntermediatePacketCodec(IntermediatePacketCodec):
return super().encode_packet(data + padding)
async def read_packet(self, reader):
raise NotImplementedError(':)')
packet_with_padding = await super().read_packet(reader)
pad_size = len(packet_with_padding) % 4
if pad_size > 0:
return packet_with_padding[:-pad_size]
return packet_with_padding
class ConnectionTcpIntermediate(Connection):
"""
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length.
"""
packet_codec = IntermediatePacketCodec

View File

@ -1,12 +1,2 @@
from .connection import Connection
from .tcpfull import ConnectionTcpFull
from .tcpintermediate import ConnectionTcpIntermediate
from .tcpabridged import ConnectionTcpAbridged
from .tcpobfuscated import ConnectionTcpObfuscated
from .tcpmtproxy import (
TcpMTProxy,
ConnectionTcpMTProxyAbridged,
ConnectionTcpMTProxyIntermediate,
ConnectionTcpMTProxyRandomizedIntermediate
)
from .http import ConnectionHttp
from .baseconnection import BaseConnection
from .asyncioconnection import AsyncioConnection

View File

@ -6,11 +6,12 @@ import sys
from ...errors import InvalidChecksumError
from ... import helpers
from .baseconnection import BaseConnection
class Connection(abc.ABC):
class AsyncioConnection(BaseConnection):
"""
The `Connection` class is a wrapper around ``asyncio.open_connection``.
The `AsyncioConnection` class is a wrapper around ``asyncio.open_connection``.
Subclasses will implement different transport modes as atomic operations,
which this class eases doing since the exposed interface simply puts and
@ -24,22 +25,15 @@ class Connection(abc.ABC):
# should be one of `PacketCodec` implementations
packet_codec = None
def __init__(self, ip, port, dc_id, *, loop, loggers, proxy=None):
self._ip = ip
self._port = port
def __init__(self, ip, port, dc_id, *, loop, codec, loggers, proxy=None):
super().__init__(ip, port, loop=loop, codec=codec)
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
self._loop = loop
self._log = loggers[__name__]
self._proxy = proxy
self._reader = None
self._writer = None
self._connected = False
self._send_task = None
self._recv_task = None
self._codec = None
self._obfuscation = None # TcpObfuscated and MTProxy
self._send_queue = asyncio.Queue(1)
self._recv_queue = asyncio.Queue(1)
async def _connect(self, timeout=None, ssl=None):
if not self._proxy:
@ -76,9 +70,14 @@ class Connection(abc.ABC):
connect_coroutine,
loop=self._loop, timeout=timeout
)
self._codec = self.packet_codec(self)
self._init_conn()
await self._writer.drain()
self._codec.__init__() # reset the codec
if self._codec.tag():
await self._send(self._codec.tag())
@property
def connected(self):
return self._connected
async def connect(self, timeout=None, ssl=None):
"""
@ -87,9 +86,6 @@ class Connection(abc.ABC):
await self._connect(timeout=timeout, ssl=ssl)
self._connected = True
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._recv_loop())
async def disconnect(self):
"""
Disconnects from the server, and clears
@ -97,12 +93,6 @@ class Connection(abc.ABC):
"""
self._connected = False
await helpers._cancel(
self._log,
send_task=self._send_task,
recv_task=self._recv_task
)
if self._writer:
self._writer.close()
if sys.version_info >= (3, 7):
@ -113,104 +103,16 @@ class Connection(abc.ABC):
# Disconnecting should never raise
self._log.warning('Unhandled %s on disconnect: %s', type(e), e)
def send(self, data):
"""
Sends a packet of data through this connection mode.
This method returns a coroutine.
"""
if not self._connected:
raise ConnectionError('Not connected')
return self._send_queue.put(data)
async def recv(self):
"""
Receives a packet of data through this connection mode.
This method returns a coroutine.
"""
while self._connected:
result = await self._recv_queue.get()
if result: # None = sentinel value = keep trying
return result
raise ConnectionError('Not connected')
async def _send_loop(self):
"""
This loop is constantly popping items off the queue to send them.
"""
try:
while self._connected:
self._send(await self._send_queue.get())
async def _send(self, data):
self._writer.write(data)
await self._writer.drain()
except asyncio.CancelledError:
async def _recv(self, length):
return await self._reader.readexactly(length)
class Connection(abc.ABC):
pass
except Exception as e:
if isinstance(e, IOError):
self._log.info('The server closed the connection while sending')
else:
self._log.exception('Unexpected exception in the send loop')
await self.disconnect()
async def _recv_loop(self):
"""
This loop is constantly putting items on the queue as they're read.
"""
while self._connected:
try:
data = await self._recv()
except asyncio.CancelledError:
break
except Exception as e:
if isinstance(e, (IOError, asyncio.IncompleteReadError)):
msg = 'The server closed the connection'
self._log.info(msg)
elif isinstance(e, InvalidChecksumError):
msg = 'The server response had an invalid checksum'
self._log.info(msg)
else:
msg = 'Unexpected exception in the receive loop'
self._log.exception(msg)
await self.disconnect()
# Add a sentinel value to unstuck recv
if self._recv_queue.empty():
self._recv_queue.put_nowait(None)
break
try:
await self._recv_queue.put(data)
except asyncio.CancelledError:
break
def _init_conn(self):
"""
This method will be called after `connect` is called.
After this method finishes, the writer will be drained.
Subclasses should make use of this if they need to send
data to Telegram to indicate which connection mode will
be used.
"""
if self._codec.tag:
self._writer.write(self._codec.tag)
def _send(self, data):
self._writer.write(self._codec.encode_packet(data))
async def _recv(self):
return await self._codec.read_packet(self._reader)
def __str__(self):
return '{}:{}/{}'.format(
self._ip, self._port,
self.__class__.__name__.replace('Connection', '')
)
class ObfuscatedConnection(Connection):

View File

@ -0,0 +1,81 @@
import abc
import asyncio
from ..codec import BaseCodec
class BaseConnection(abc.ABC):
    """
    The base connection class.

    It offers atomic send and receive methods.

    Subclasses are only responsible for sending and receiving raw
    data (`_send` and `_recv`), since this base class already makes
    use of the given codec to frame outgoing packets and to unframe
    incoming ones.
    """
    def __init__(self, ip: str, port: int, *,
                 loop: asyncio.AbstractEventLoop, codec: 'BaseCodec'):
        """
        Stores the address, event loop and codec, and creates the
        locks that serialize concurrent `send` and `recv` calls.
        """
        self._ip = ip
        self._port = port
        self._loop = loop
        self._codec = codec
        self._send_lock = self._create_lock(loop)
        self._recv_lock = self._create_lock(loop)

    @staticmethod
    def _create_lock(loop):
        """
        Creates an `asyncio.Lock`, binding it to ``loop`` when the
        running Python version still accepts that argument.
        """
        try:
            return asyncio.Lock(loop=loop)
        except TypeError:
            # The ``loop`` parameter was removed in Python 3.10.
            return asyncio.Lock()

    @property
    @abc.abstractmethod
    def connected(self):
        """
        Whether the connection is currently established.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def connect(self):
        """
        Establishes the connection.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def disconnect(self):
        """
        Closes the connection.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _send(self, data):
        """
        Sends the already-encoded ``data`` as-is.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _recv(self, length):
        """
        Receives exactly ``length`` raw bytes.
        """
        raise NotImplementedError

    async def send(self, data):
        """
        Encodes ``data`` with the codec and sends the result.

        Raises `ConnectionError` if the connection is not established.
        """
        if not self.connected:
            raise ConnectionError('Not connected')

        # TODO Handle asyncio.CancelledError, IOError, Exception
        data = self._codec.encode_packet(data)
        async with self._send_lock:
            return await self._send(data)

    async def recv(self):
        """
        Receives one full packet and returns its decoded body.

        Raises `ConnectionError` if the connection is not established
        or if anything fails while reading or decoding the packet; in
        the latter case the original exception is kept as the cause.
        Cancellation is propagated untouched.
        """
        if not self.connected:
            raise ConnectionError('Not connected')

        # TODO Handle asyncio.IncompleteReadError, IOError,
        #      InvalidChecksumError, Exception properly
        async with self._recv_lock:
            try:
                header = await self._recv(self._codec.header_length())
                length = self._codec.decode_header(header)
                # A negative length means the header is incomplete and
                # that many more bytes must be read before decoding it.
                while length < 0:
                    header += await self._recv(-length)
                    length = self._codec.decode_header(header)

                body = await self._recv(length)
                return self._codec.decode_body(header, body)
            except asyncio.CancelledError:
                # Before Python 3.8 this is an `Exception` subclass;
                # cancelling must never look like a connection error.
                raise
            except Exception as e:
                raise ConnectionError(
                    'Failed to receive a packet: {!r}'.format(e)
                ) from e

    def __str__(self):
        return '{}:{}/{}'.format(
            self._ip, self._port,
            self.__class__.__name__.replace('Connection', '')
        )

View File

@ -1,15 +1,16 @@
import asyncio
from .connection import Connection, PacketCodec
SSL_PORT = 443
class HttpPacketCodec(PacketCodec):
class HttpPacketCodec:
tag = None
obfuscate_tag = None
def __init__(self):
raise NotImplementedError('Not migrated yet')
def encode_packet(self, data):
return ('POST /api HTTP/1.1\r\n'
'Host: {}:{}\r\n'
@ -32,8 +33,11 @@ class HttpPacketCodec(PacketCodec):
return await reader.readexactly(length)
class ConnectionHttp(Connection):
class ConnectionHttp:
packet_codec = HttpPacketCodec
def __init__(self):
raise NotImplementedError('Not migrated yet')
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)

View File

@ -1,33 +0,0 @@
import struct
from .connection import Connection, PacketCodec
class AbridgedPacketCodec(PacketCodec):
    """
    Packet codec whose length prefix counts 4-byte words: one byte
    for short packets, or ``0x7f`` plus three little-endian bytes.
    """
    tag = b'\xef'
    obfuscate_tag = b'\xef\xef\xef\xef'

    def encode_packet(self, data):
        """
        Returns ``data`` prefixed with its length in 4-byte words.
        """
        words = len(data) >> 2
        if words >= 127:
            prefix = b'\x7f' + words.to_bytes(3, 'little')
        else:
            prefix = struct.pack('B', words)

        return prefix + data

    async def read_packet(self, reader):
        """
        Reads the length prefix from ``reader`` and then returns
        exactly that many bytes of payload.
        """
        (words,) = struct.unpack('<B', await reader.readexactly(1))
        if words >= 127:
            # Long form: the real word count is in the next 3 bytes.
            tail = await reader.readexactly(3)
            (words,) = struct.unpack('<i', tail + b'\0')

        return await reader.readexactly(words << 2)
class ConnectionTcpAbridged(Connection):
    """
    Connection mode with the lowest overhead: a single length byte
    suffices when the packet is shorter than 508 bytes (127 << 2),
    which is very common.
    """
    packet_codec = AbridgedPacketCodec