mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-03-12 15:38:03 +03:00
Make a clear distinction between connection and codec
This commit is contained in:
parent
f9ca17c99f
commit
f082a27ff8
|
@ -10,7 +10,7 @@ from .. import version, helpers, __name__ as __base_name__
|
||||||
from ..crypto import rsa
|
from ..crypto import rsa
|
||||||
from ..entitycache import EntityCache
|
from ..entitycache import EntityCache
|
||||||
from ..extensions import markdown
|
from ..extensions import markdown
|
||||||
from ..network import MTProtoSender, Connection, ConnectionTcpFull, TcpMTProxy
|
from ..network import MTProtoSender, AsyncioConnection, BaseCodec, FullCodec
|
||||||
from ..sessions import Session, SQLiteSession, MemorySession
|
from ..sessions import Session, SQLiteSession, MemorySession
|
||||||
from ..statecache import StateCache
|
from ..statecache import StateCache
|
||||||
from ..tl import TLObject, functions, types
|
from ..tl import TLObject, functions, types
|
||||||
|
@ -167,7 +167,7 @@ class TelegramBaseClient(abc.ABC):
|
||||||
api_id: int,
|
api_id: int,
|
||||||
api_hash: str,
|
api_hash: str,
|
||||||
*,
|
*,
|
||||||
connection: 'typing.Type[Connection]' = ConnectionTcpFull,
|
connection: 'typing.Type[BaseCodec]' = FullCodec, # TODO rename
|
||||||
use_ipv6: bool = False,
|
use_ipv6: bool = False,
|
||||||
proxy: typing.Union[str, dict] = None,
|
proxy: typing.Union[str, dict] = None,
|
||||||
timeout: int = 10,
|
timeout: int = 10,
|
||||||
|
@ -257,9 +257,10 @@ class TelegramBaseClient(abc.ABC):
|
||||||
self._auto_reconnect = auto_reconnect
|
self._auto_reconnect = auto_reconnect
|
||||||
|
|
||||||
assert isinstance(connection, type)
|
assert isinstance(connection, type)
|
||||||
self._connection = connection
|
self._codec = connection
|
||||||
init_proxy = None if not issubclass(connection, TcpMTProxy) else \
|
|
||||||
types.InputClientProxy(*connection.address_info(self._proxy))
|
# TODO set types.InputClientProxy if appropriated
|
||||||
|
init_proxy = None
|
||||||
|
|
||||||
# Used on connection. Capture the variables in a lambda since
|
# Used on connection. Capture the variables in a lambda since
|
||||||
# exporting clients need to create this InvokeWithLayerRequest.
|
# exporting clients need to create this InvokeWithLayerRequest.
|
||||||
|
@ -415,10 +416,11 @@ class TelegramBaseClient(abc.ABC):
|
||||||
except OSError:
|
except OSError:
|
||||||
print('Failed to connect')
|
print('Failed to connect')
|
||||||
"""
|
"""
|
||||||
await self._sender.connect(self._connection(
|
await self._sender.connect(AsyncioConnection(
|
||||||
self._session.server_address,
|
self._session.server_address,
|
||||||
self._session.port,
|
self._session.port,
|
||||||
self._session.dc_id,
|
self._session.dc_id,
|
||||||
|
codec=self._codec(),
|
||||||
loop=self._loop,
|
loop=self._loop,
|
||||||
loggers=self._log,
|
loggers=self._log,
|
||||||
proxy=self._proxy
|
proxy=self._proxy
|
||||||
|
|
|
@ -5,10 +5,5 @@ with Telegram's servers and the protocol used (TCP full, abridged, etc.).
|
||||||
from .mtprotoplainsender import MTProtoPlainSender
|
from .mtprotoplainsender import MTProtoPlainSender
|
||||||
from .authenticator import do_authentication
|
from .authenticator import do_authentication
|
||||||
from .mtprotosender import MTProtoSender
|
from .mtprotosender import MTProtoSender
|
||||||
from .connection import (
|
from .codec import BaseCodec, FullCodec, IntermediateCodec, AbridgedCodec
|
||||||
Connection,
|
from .connection import BaseConnection, AsyncioConnection
|
||||||
ConnectionTcpFull, ConnectionTcpIntermediate, ConnectionTcpAbridged,
|
|
||||||
ConnectionTcpObfuscated, ConnectionTcpMTProxyAbridged,
|
|
||||||
ConnectionTcpMTProxyIntermediate,
|
|
||||||
ConnectionTcpMTProxyRandomizedIntermediate, ConnectionHttp, TcpMTProxy
|
|
||||||
)
|
|
||||||
|
|
4
telethon/network/codec/__init__.py
Normal file
4
telethon/network/codec/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
from .basecodec import BaseCodec
|
||||||
|
from .fullcodec import FullCodec
|
||||||
|
from .intermediatecodec import IntermediateCodec
|
||||||
|
from .abridgedcodec import AbridgedCodec
|
37
telethon/network/codec/abridgedcodec.py
Normal file
37
telethon/network/codec/abridgedcodec.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from .basecodec import BaseCodec
|
||||||
|
|
||||||
|
|
||||||
|
class AbridgedCodec(BaseCodec):
|
||||||
|
"""
|
||||||
|
This is the mode with the lowest overhead, as it will
|
||||||
|
only require 1 byte if the packet length is less than
|
||||||
|
508 bytes (127 << 2, which is very common).
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def header_length():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tag():
|
||||||
|
return b'\xef' # note: obfuscated tag is this 4 times
|
||||||
|
|
||||||
|
def encode_packet(self, data):
|
||||||
|
length = len(data) >> 2
|
||||||
|
if length < 127:
|
||||||
|
length = struct.pack('B', length)
|
||||||
|
else:
|
||||||
|
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
||||||
|
|
||||||
|
return length + data
|
||||||
|
|
||||||
|
def decode_header(self, header):
|
||||||
|
if len(header) == 4:
|
||||||
|
length = struct.unpack('<i', header[1:] + b'\0')[0]
|
||||||
|
else:
|
||||||
|
length = struct.unpack('<B', header)[0]
|
||||||
|
if length >= 127:
|
||||||
|
return -3 # needs 3 more bytes
|
||||||
|
|
||||||
|
return length << 2
|
53
telethon/network/codec/basecodec.py
Normal file
53
telethon/network/codec/basecodec.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCodec(abc.ABC):
|
||||||
|
@staticmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def header_length():
|
||||||
|
"""
|
||||||
|
Returns the initial length of the header.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def tag():
|
||||||
|
"""
|
||||||
|
The bytes tag that identifies the codec.
|
||||||
|
|
||||||
|
It may be ``None`` if there is no tag to send.
|
||||||
|
|
||||||
|
The tag will be sent upon successful connections to the
|
||||||
|
server so that it knows which codec we will be using next.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def encode_packet(self, data):
|
||||||
|
"""
|
||||||
|
Encodes the given data with the current codec instance.
|
||||||
|
|
||||||
|
Should return header + body.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def decode_header(self, header):
|
||||||
|
"""
|
||||||
|
Decodes the header.
|
||||||
|
|
||||||
|
Should return the length of the body as a positive number.
|
||||||
|
|
||||||
|
If more data is needed, a ``-length`` should be returned, where
|
||||||
|
``length`` is how much more data is needed for the full header.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode_body(self, header, body):
|
||||||
|
"""
|
||||||
|
Decodes the body.
|
||||||
|
|
||||||
|
The default implementation returns ``body``.
|
||||||
|
"""
|
||||||
|
return body
|
|
@ -1,43 +1,45 @@
|
||||||
import struct
|
import struct
|
||||||
from zlib import crc32
|
import zlib
|
||||||
|
|
||||||
from .connection import Connection, PacketCodec
|
from .basecodec import BaseCodec
|
||||||
from ...errors import InvalidChecksumError
|
from ...errors import InvalidChecksumError
|
||||||
|
|
||||||
|
|
||||||
class FullPacketCodec(PacketCodec):
|
class FullCodec(BaseCodec):
|
||||||
tag = None
|
"""
|
||||||
|
Default Telegram codec. Sends 12 additional bytes and
|
||||||
def __init__(self, connection):
|
needs to calculate the CRC value of the packet itself.
|
||||||
super().__init__(connection)
|
"""
|
||||||
|
def __init__(self):
|
||||||
self._send_counter = 0 # Important or Telegram won't reply
|
self._send_counter = 0 # Important or Telegram won't reply
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def header_length():
|
||||||
|
return 8
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tag():
|
||||||
|
return None
|
||||||
|
|
||||||
def encode_packet(self, data):
|
def encode_packet(self, data):
|
||||||
# https://core.telegram.org/mtproto#tcp-transport
|
# https://core.telegram.org/mtproto#tcp-transport
|
||||||
# total length, sequence number, packet and checksum (CRC32)
|
# total length, sequence number, packet and checksum (CRC32)
|
||||||
length = len(data) + 12
|
length = len(data) + 12
|
||||||
data = struct.pack('<ii', length, self._send_counter) + data
|
data = struct.pack('<ii', length, self._send_counter) + data
|
||||||
crc = struct.pack('<I', crc32(data))
|
crc = struct.pack('<I', zlib.crc32(data))
|
||||||
self._send_counter += 1
|
self._send_counter += 1
|
||||||
return data + crc
|
return data + crc
|
||||||
|
|
||||||
async def read_packet(self, reader):
|
def decode_header(self, header):
|
||||||
packet_len_seq = await reader.readexactly(8) # 4 and 4
|
length, seq = struct.unpack('<ii', header)
|
||||||
packet_len, seq = struct.unpack('<ii', packet_len_seq)
|
return length - 8
|
||||||
body = await reader.readexactly(packet_len - 8)
|
|
||||||
|
def decode_body(self, header, body):
|
||||||
checksum = struct.unpack('<I', body[-4:])[0]
|
checksum = struct.unpack('<I', body[-4:])[0]
|
||||||
body = body[:-4]
|
body = body[:-4]
|
||||||
|
|
||||||
valid_checksum = crc32(packet_len_seq + body)
|
valid_checksum = zlib.crc32(header + body)
|
||||||
if checksum != valid_checksum:
|
if checksum != valid_checksum:
|
||||||
raise InvalidChecksumError(checksum, valid_checksum)
|
raise InvalidChecksumError(checksum, valid_checksum)
|
||||||
|
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTcpFull(Connection):
|
|
||||||
"""
|
|
||||||
Default Telegram mode. Sends 12 additional bytes and
|
|
||||||
needs to calculate the CRC value of the packet itself.
|
|
||||||
"""
|
|
||||||
packet_codec = FullPacketCodec
|
|
|
@ -2,25 +2,33 @@ import struct
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .connection import Connection, PacketCodec
|
from .basecodec import BaseCodec
|
||||||
|
|
||||||
|
|
||||||
class IntermediatePacketCodec(PacketCodec):
|
class IntermediateCodec(BaseCodec):
|
||||||
tag = b'\xee\xee\xee\xee'
|
"""
|
||||||
obfuscate_tag = tag
|
Intermediate mode between `FullCodec` and `AbridgedCodec`.
|
||||||
|
Always sends 4 extra bytes for the packet length.
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def header_length():
|
||||||
|
return 4
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tag():
|
||||||
|
return b'\xee\xee\xee\xee' # same as obfuscate tag
|
||||||
|
|
||||||
def encode_packet(self, data):
|
def encode_packet(self, data):
|
||||||
return struct.pack('<i', len(data)) + data
|
return struct.pack('<i', len(data)) + data
|
||||||
|
|
||||||
async def read_packet(self, reader):
|
def decode_header(self, header):
|
||||||
length = struct.unpack('<i', await reader.readexactly(4))[0]
|
return struct.unpack('<i', header)[0]
|
||||||
return await reader.readexactly(length)
|
|
||||||
|
|
||||||
|
|
||||||
class RandomizedIntermediatePacketCodec(IntermediatePacketCodec):
|
class RandomizedIntermediateCodec(IntermediateCodec):
|
||||||
"""
|
"""
|
||||||
Data packets are aligned to 4bytes. This codec adds random bytes of size
|
Data packets are aligned to 4 bytes. This codec adds random
|
||||||
from 0 to 3 bytes, which are ignored by decoder.
|
bytes of size from 0 to 3 bytes, which are ignored by decoder.
|
||||||
"""
|
"""
|
||||||
tag = None
|
tag = None
|
||||||
obfuscate_tag = b'\xdd\xdd\xdd\xdd'
|
obfuscate_tag = b'\xdd\xdd\xdd\xdd'
|
||||||
|
@ -31,16 +39,9 @@ class RandomizedIntermediatePacketCodec(IntermediatePacketCodec):
|
||||||
return super().encode_packet(data + padding)
|
return super().encode_packet(data + padding)
|
||||||
|
|
||||||
async def read_packet(self, reader):
|
async def read_packet(self, reader):
|
||||||
|
raise NotImplementedError(':)')
|
||||||
packet_with_padding = await super().read_packet(reader)
|
packet_with_padding = await super().read_packet(reader)
|
||||||
pad_size = len(packet_with_padding) % 4
|
pad_size = len(packet_with_padding) % 4
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
return packet_with_padding[:-pad_size]
|
return packet_with_padding[:-pad_size]
|
||||||
return packet_with_padding
|
return packet_with_padding
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTcpIntermediate(Connection):
|
|
||||||
"""
|
|
||||||
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
|
|
||||||
Always sends 4 extra bytes for the packet length.
|
|
||||||
"""
|
|
||||||
packet_codec = IntermediatePacketCodec
|
|
|
@ -1,12 +1,2 @@
|
||||||
from .connection import Connection
|
from .baseconnection import BaseConnection
|
||||||
from .tcpfull import ConnectionTcpFull
|
from .asyncioconnection import AsyncioConnection
|
||||||
from .tcpintermediate import ConnectionTcpIntermediate
|
|
||||||
from .tcpabridged import ConnectionTcpAbridged
|
|
||||||
from .tcpobfuscated import ConnectionTcpObfuscated
|
|
||||||
from .tcpmtproxy import (
|
|
||||||
TcpMTProxy,
|
|
||||||
ConnectionTcpMTProxyAbridged,
|
|
||||||
ConnectionTcpMTProxyIntermediate,
|
|
||||||
ConnectionTcpMTProxyRandomizedIntermediate
|
|
||||||
)
|
|
||||||
from .http import ConnectionHttp
|
|
||||||
|
|
|
@ -6,11 +6,12 @@ import sys
|
||||||
|
|
||||||
from ...errors import InvalidChecksumError
|
from ...errors import InvalidChecksumError
|
||||||
from ... import helpers
|
from ... import helpers
|
||||||
|
from .baseconnection import BaseConnection
|
||||||
|
|
||||||
|
|
||||||
class Connection(abc.ABC):
|
class AsyncioConnection(BaseConnection):
|
||||||
"""
|
"""
|
||||||
The `Connection` class is a wrapper around ``asyncio.open_connection``.
|
The `AsyncioConnection` class is a wrapper around ``asyncio.open_connection``.
|
||||||
|
|
||||||
Subclasses will implement different transport modes as atomic operations,
|
Subclasses will implement different transport modes as atomic operations,
|
||||||
which this class eases doing since the exposed interface simply puts and
|
which this class eases doing since the exposed interface simply puts and
|
||||||
|
@ -24,22 +25,15 @@ class Connection(abc.ABC):
|
||||||
# should be one of `PacketCodec` implementations
|
# should be one of `PacketCodec` implementations
|
||||||
packet_codec = None
|
packet_codec = None
|
||||||
|
|
||||||
def __init__(self, ip, port, dc_id, *, loop, loggers, proxy=None):
|
def __init__(self, ip, port, dc_id, *, loop, codec, loggers, proxy=None):
|
||||||
self._ip = ip
|
super().__init__(ip, port, loop=loop, codec=codec)
|
||||||
self._port = port
|
|
||||||
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
|
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
|
||||||
self._loop = loop
|
|
||||||
self._log = loggers[__name__]
|
self._log = loggers[__name__]
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
self._reader = None
|
self._reader = None
|
||||||
self._writer = None
|
self._writer = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._send_task = None
|
|
||||||
self._recv_task = None
|
|
||||||
self._codec = None
|
|
||||||
self._obfuscation = None # TcpObfuscated and MTProxy
|
self._obfuscation = None # TcpObfuscated and MTProxy
|
||||||
self._send_queue = asyncio.Queue(1)
|
|
||||||
self._recv_queue = asyncio.Queue(1)
|
|
||||||
|
|
||||||
async def _connect(self, timeout=None, ssl=None):
|
async def _connect(self, timeout=None, ssl=None):
|
||||||
if not self._proxy:
|
if not self._proxy:
|
||||||
|
@ -76,9 +70,14 @@ class Connection(abc.ABC):
|
||||||
connect_coroutine,
|
connect_coroutine,
|
||||||
loop=self._loop, timeout=timeout
|
loop=self._loop, timeout=timeout
|
||||||
)
|
)
|
||||||
self._codec = self.packet_codec(self)
|
|
||||||
self._init_conn()
|
self._codec.__init__() # reset the codec
|
||||||
await self._writer.drain()
|
if self._codec.tag():
|
||||||
|
await self._send(self._codec.tag())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected(self):
|
||||||
|
return self._connected
|
||||||
|
|
||||||
async def connect(self, timeout=None, ssl=None):
|
async def connect(self, timeout=None, ssl=None):
|
||||||
"""
|
"""
|
||||||
|
@ -87,9 +86,6 @@ class Connection(abc.ABC):
|
||||||
await self._connect(timeout=timeout, ssl=ssl)
|
await self._connect(timeout=timeout, ssl=ssl)
|
||||||
self._connected = True
|
self._connected = True
|
||||||
|
|
||||||
self._send_task = self._loop.create_task(self._send_loop())
|
|
||||||
self._recv_task = self._loop.create_task(self._recv_loop())
|
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""
|
"""
|
||||||
Disconnects from the server, and clears
|
Disconnects from the server, and clears
|
||||||
|
@ -97,12 +93,6 @@ class Connection(abc.ABC):
|
||||||
"""
|
"""
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
await helpers._cancel(
|
|
||||||
self._log,
|
|
||||||
send_task=self._send_task,
|
|
||||||
recv_task=self._recv_task
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._writer:
|
if self._writer:
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
if sys.version_info >= (3, 7):
|
if sys.version_info >= (3, 7):
|
||||||
|
@ -113,104 +103,16 @@ class Connection(abc.ABC):
|
||||||
# Disconnecting should never raise
|
# Disconnecting should never raise
|
||||||
self._log.warning('Unhandled %s on disconnect: %s', type(e), e)
|
self._log.warning('Unhandled %s on disconnect: %s', type(e), e)
|
||||||
|
|
||||||
def send(self, data):
|
async def _send(self, data):
|
||||||
"""
|
self._writer.write(data)
|
||||||
Sends a packet of data through this connection mode.
|
await self._writer.drain()
|
||||||
|
|
||||||
This method returns a coroutine.
|
async def _recv(self, length):
|
||||||
"""
|
return await self._reader.readexactly(length)
|
||||||
if not self._connected:
|
|
||||||
raise ConnectionError('Not connected')
|
|
||||||
|
|
||||||
return self._send_queue.put(data)
|
|
||||||
|
|
||||||
async def recv(self):
|
class Connection(abc.ABC):
|
||||||
"""
|
pass
|
||||||
Receives a packet of data through this connection mode.
|
|
||||||
|
|
||||||
This method returns a coroutine.
|
|
||||||
"""
|
|
||||||
while self._connected:
|
|
||||||
result = await self._recv_queue.get()
|
|
||||||
if result: # None = sentinel value = keep trying
|
|
||||||
return result
|
|
||||||
|
|
||||||
raise ConnectionError('Not connected')
|
|
||||||
|
|
||||||
async def _send_loop(self):
|
|
||||||
"""
|
|
||||||
This loop is constantly popping items off the queue to send them.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
while self._connected:
|
|
||||||
self._send(await self._send_queue.get())
|
|
||||||
await self._writer.drain()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, IOError):
|
|
||||||
self._log.info('The server closed the connection while sending')
|
|
||||||
else:
|
|
||||||
self._log.exception('Unexpected exception in the send loop')
|
|
||||||
|
|
||||||
await self.disconnect()
|
|
||||||
|
|
||||||
async def _recv_loop(self):
|
|
||||||
"""
|
|
||||||
This loop is constantly putting items on the queue as they're read.
|
|
||||||
"""
|
|
||||||
while self._connected:
|
|
||||||
try:
|
|
||||||
data = await self._recv()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, (IOError, asyncio.IncompleteReadError)):
|
|
||||||
msg = 'The server closed the connection'
|
|
||||||
self._log.info(msg)
|
|
||||||
elif isinstance(e, InvalidChecksumError):
|
|
||||||
msg = 'The server response had an invalid checksum'
|
|
||||||
self._log.info(msg)
|
|
||||||
else:
|
|
||||||
msg = 'Unexpected exception in the receive loop'
|
|
||||||
self._log.exception(msg)
|
|
||||||
|
|
||||||
await self.disconnect()
|
|
||||||
|
|
||||||
# Add a sentinel value to unstuck recv
|
|
||||||
if self._recv_queue.empty():
|
|
||||||
self._recv_queue.put_nowait(None)
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._recv_queue.put(data)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _init_conn(self):
|
|
||||||
"""
|
|
||||||
This method will be called after `connect` is called.
|
|
||||||
After this method finishes, the writer will be drained.
|
|
||||||
|
|
||||||
Subclasses should make use of this if they need to send
|
|
||||||
data to Telegram to indicate which connection mode will
|
|
||||||
be used.
|
|
||||||
"""
|
|
||||||
if self._codec.tag:
|
|
||||||
self._writer.write(self._codec.tag)
|
|
||||||
|
|
||||||
def _send(self, data):
|
|
||||||
self._writer.write(self._codec.encode_packet(data))
|
|
||||||
|
|
||||||
async def _recv(self):
|
|
||||||
return await self._codec.read_packet(self._reader)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return '{}:{}/{}'.format(
|
|
||||||
self._ip, self._port,
|
|
||||||
self.__class__.__name__.replace('Connection', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ObfuscatedConnection(Connection):
|
class ObfuscatedConnection(Connection):
|
81
telethon/network/connection/baseconnection.py
Normal file
81
telethon/network/connection/baseconnection.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import abc
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from ..codec import BaseCodec
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConnection(abc.ABC):
|
||||||
|
"""
|
||||||
|
The base connection class.
|
||||||
|
|
||||||
|
It offers atomic send and receive methods.
|
||||||
|
|
||||||
|
Subclasses are only responsible of sending and receiving data,
|
||||||
|
since this base class already makes use of the given codec for
|
||||||
|
correctly adapting the data.
|
||||||
|
"""
|
||||||
|
def __init__(self, ip: str, port: int, *, loop: asyncio.AbstractEventLoop, codec: BaseCodec):
|
||||||
|
self._ip = ip
|
||||||
|
self._port = port
|
||||||
|
self._loop = loop
|
||||||
|
self._codec = codec
|
||||||
|
self._send_lock = asyncio.Lock(loop=loop)
|
||||||
|
self._recv_lock = asyncio.Lock(loop=loop)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def connected(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def connect(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def disconnect(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def _send(self, data):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def _recv(self, length):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def send(self, data):
|
||||||
|
if not self.connected:
|
||||||
|
raise ConnectionError('Not connected')
|
||||||
|
|
||||||
|
# TODO Handle asyncio.CancelledError, IOError, Exception
|
||||||
|
data = self._codec.encode_packet(data)
|
||||||
|
async with self._send_lock:
|
||||||
|
return await self._send(data)
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
if not self.connected:
|
||||||
|
raise ConnectionError('Not connected')
|
||||||
|
|
||||||
|
# TODO Handle asyncio.CancelledError, asyncio.IncompleteReadError,
|
||||||
|
# IOError, InvalidChecksumError, Exception properly
|
||||||
|
await self._recv_lock.acquire()
|
||||||
|
try:
|
||||||
|
header = await self._recv(self._codec.header_length())
|
||||||
|
|
||||||
|
length = self._codec.decode_header(header)
|
||||||
|
while length < 0:
|
||||||
|
header += await self._recv(-length)
|
||||||
|
length = self._codec.decode_header(header)
|
||||||
|
|
||||||
|
body = await self._recv(length)
|
||||||
|
return self._codec.decode_body(header, body)
|
||||||
|
except Exception:
|
||||||
|
raise ConnectionError
|
||||||
|
finally:
|
||||||
|
self._recv_lock.release()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return '{}:{}/{}'.format(
|
||||||
|
self._ip, self._port,
|
||||||
|
self.__class__.__name__.replace('Connection', '')
|
||||||
|
)
|
|
@ -1,15 +1,16 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from .connection import Connection, PacketCodec
|
|
||||||
|
|
||||||
|
|
||||||
SSL_PORT = 443
|
SSL_PORT = 443
|
||||||
|
|
||||||
|
|
||||||
class HttpPacketCodec(PacketCodec):
|
class HttpPacketCodec:
|
||||||
tag = None
|
tag = None
|
||||||
obfuscate_tag = None
|
obfuscate_tag = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplementedError('Not migrated yet')
|
||||||
|
|
||||||
def encode_packet(self, data):
|
def encode_packet(self, data):
|
||||||
return ('POST /api HTTP/1.1\r\n'
|
return ('POST /api HTTP/1.1\r\n'
|
||||||
'Host: {}:{}\r\n'
|
'Host: {}:{}\r\n'
|
||||||
|
@ -32,8 +33,11 @@ class HttpPacketCodec(PacketCodec):
|
||||||
return await reader.readexactly(length)
|
return await reader.readexactly(length)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionHttp(Connection):
|
class ConnectionHttp:
|
||||||
packet_codec = HttpPacketCodec
|
packet_codec = HttpPacketCodec
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise NotImplementedError('Not migrated yet')
|
||||||
|
|
||||||
async def connect(self, timeout=None, ssl=None):
|
async def connect(self, timeout=None, ssl=None):
|
||||||
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)
|
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)
|
||||||
|
|
|
@ -1,33 +0,0 @@
|
||||||
import struct
|
|
||||||
|
|
||||||
from .connection import Connection, PacketCodec
|
|
||||||
|
|
||||||
|
|
||||||
class AbridgedPacketCodec(PacketCodec):
|
|
||||||
tag = b'\xef'
|
|
||||||
obfuscate_tag = b'\xef\xef\xef\xef'
|
|
||||||
|
|
||||||
def encode_packet(self, data):
|
|
||||||
length = len(data) >> 2
|
|
||||||
if length < 127:
|
|
||||||
length = struct.pack('B', length)
|
|
||||||
else:
|
|
||||||
length = b'\x7f' + int.to_bytes(length, 3, 'little')
|
|
||||||
return length + data
|
|
||||||
|
|
||||||
async def read_packet(self, reader):
|
|
||||||
length = struct.unpack('<B', await reader.readexactly(1))[0]
|
|
||||||
if length >= 127:
|
|
||||||
length = struct.unpack(
|
|
||||||
'<i', await reader.readexactly(3) + b'\0')[0]
|
|
||||||
|
|
||||||
return await reader.readexactly(length << 2)
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTcpAbridged(Connection):
|
|
||||||
"""
|
|
||||||
This is the mode with the lowest overhead, as it will
|
|
||||||
only require 1 byte if the packet length is less than
|
|
||||||
508 bytes (127 << 2, which is very common).
|
|
||||||
"""
|
|
||||||
packet_codec = AbridgedPacketCodec
|
|
Loading…
Reference in New Issue
Block a user