diff --git a/telethon/network/connection/common.py b/telethon/network/connection/common.py deleted file mode 100644 index a57c248e..00000000 --- a/telethon/network/connection/common.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -This module holds the abstract `Connection` class. - -The `Connection.send` and `Connection.recv` methods need **not** to be -safe across several tasks and may use any amount of ``await`` keywords. - -The code using these `Connection`'s should be responsible for using -an ``async with asyncio.Lock:`` block when calling said methods. - -Said subclasses need not to worry about reconnecting either, and -should let the errors propagate instead. -""" -import abc - - -class Connection(abc.ABC): - """ - Represents an abstract connection for Telegram. - - Subclasses should implement the actual protocol - being used when encoding/decoding messages. - """ - def __init__(self, *, loop, timeout, proxy=None): - """ - Initializes a new connection. - - :param loop: the event loop to be used. - :param timeout: timeout to be used for all operations. - :param proxy: whether to use a proxy or not. - """ - self._loop = loop - self._proxy = proxy - self._timeout = timeout - - @abc.abstractmethod - async def connect(self, ip, port): - raise NotImplementedError - - @abc.abstractmethod - def get_timeout(self): - """Returns the timeout used by the connection.""" - raise NotImplementedError - - @abc.abstractmethod - def is_connected(self): - """ - Determines whether the connection is alive or not. - - :return: true if it's connected. - """ - raise NotImplementedError - - @abc.abstractmethod - async def close(self): - """Closes the connection.""" - raise NotImplementedError - - def clone(self): - """Creates a copy of this Connection.""" - return self.__class__( - loop=self._loop, - proxy=self._proxy, - timeout=self._timeout - ) - - @abc.abstractmethod - async def recv(self): - """Receives and unpacks a message""" - raise NotImplementedError - - @abc.abstractmethod - async def send(self, message): - """Encapsulates and sends the given message""" - raise NotImplementedError diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index 60ba2f89..5d96fc5e 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -69,6 +69,7 @@ class Connection(abc.ABC): self._send(await self._send_queue.get()) await self._writer.drain() + # TODO Handle IncompleteReadError and InvalidChecksumError async def _recv_loop(self): """ This loop is constantly putting items on the queue as they're read. diff --git a/telethon/network/connection/http.py b/telethon/network/connection/http.py index 955b9ab3..c346d2f8 100644 --- a/telethon/network/connection/http.py +++ b/telethon/network/connection/http.py @@ -1,62 +1,38 @@ -import errno -import ssl +import asyncio -from .common import Connection -from ...extensions import TcpClient +from .connection import Connection class ConnectionHttp(Connection): - def __init__(self, *, loop, timeout, proxy=None): - super().__init__(loop=loop, timeout=timeout, proxy=proxy) - self.conn = TcpClient( - timeout=self._timeout, loop=self._loop, proxy=self._proxy, - ssl=dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA') - ) - self.read = self.conn.read - self.write = self.conn.write - self._host = None + async def connect(self): + # TODO Test if the ssl part works or it needs to be as before: + # dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA') + self._reader, self._writer = await asyncio.open_connection( + self._ip, self._port, loop=self._loop, ssl=True) - async def connect(self, ip, port): - self._host = '{}:{}'.format(ip, port) - try: - await self.conn.connect(ip, port) - except OSError as e: - if e.errno == errno.EISCONN: - return # Already connected, no need to re-set everything up - else: - raise + self._disconnected.clear() + self._send_task = self._loop.create_task(self._send_loop()) + self._recv_task = self._loop.create_task(self._send_loop()) - def get_timeout(self): - return self.conn.timeout - - def is_connected(self): - return self.conn.is_connected - - async def close(self): - self.conn.close() - - async def recv(self): - while True: - line = await self._read_line() - if line.lower().startswith(b'content-length: '): - await self.read(2) - length = int(line[16:-2]) - return await self.read(length) - - async def _read_line(self): - newline = ord('\n') - line = await self.read(1) - while line[-1] != newline: - line += await self.read(1) - return line - - async def send(self, message): - await self.write( + def _send(self, message): + self._writer.write( 'POST /api HTTP/1.1\r\n' - 'Host: {}\r\n' + 'Host: {}:{}\r\n' 'Content-Type: application/x-www-form-urlencoded\r\n' 'Connection: keep-alive\r\n' 'Keep-Alive: timeout=100000, max=10000000\r\n' - 'Content-Length: {}\r\n\r\n'.format(self._host, len(message)) + 'Content-Length: {}\r\n\r\n' + .format(self._ip, self._port, len(message)) .encode('ascii') + message ) + + async def _recv(self): + while True: + line = await self._reader.readline() + if not line or line[-1] != b'\n': + raise asyncio.IncompleteReadError(line, None) + + if line.lower().startswith(b'content-length: '): + await self._reader.readexactly(2) + length = int(line[16:-2]) + return await self._reader.readexactly(length) diff --git a/telethon/network/connection/tcpabridged.py b/telethon/network/connection/tcpabridged.py index d5943908..6352413e 100644 --- a/telethon/network/connection/tcpabridged.py +++ b/telethon/network/connection/tcpabridged.py @@ -1,31 +1,43 @@ import struct -from .tcpfull import ConnectionTcpFull +from .connection import Connection -class ConnectionTcpAbridged(ConnectionTcpFull): +class ConnectionTcpAbridged(Connection): """ This is the mode with the lowest overhead, as it will only require 1 byte if the packet length is less than 508 bytes (127 << 2, which is very common). """ - async def connect(self, ip, port): - result = await super().connect(ip, port) - await self.conn.write(b'\xef') - return result + async def connect(self): + await super().connect() + await self.send(b'\xef') - async def recv(self): - length = struct.unpack('= 127: - length = struct.unpack('> 2 + def _send(self, data): + length = len(data) >> 2 if length < 127: length = struct.pack('B', length) else: length = b'\x7f' + int.to_bytes(length, 3, 'little') - await self.write(length + message) + self._write(length + data) + + async def _recv(self): + length = struct.unpack('= 127: + length = struct.unpack( + '