Telethon/telethon/network/connection/connection.py

180 lines
5.5 KiB
Python

import abc
import asyncio
import logging
import socket
import ssl as ssl_mod
# Logger named after this module; used by the connection machinery below.
__log__ = logging.getLogger(name=__name__)
class Connection(abc.ABC):
    """
    The `Connection` class is a wrapper around ``asyncio.open_connection``.

    Subclasses are meant to communicate with this class through a queue.

    This class provides a reliable interface that will stay connected
    under any conditions for as long as the user doesn't disconnect or
    the input parameters to auto-reconnect dictate otherwise.
    """
    def __init__(self, ip, port, *, loop, proxy=None):
        self._ip = ip
        self._port = port
        self._loop = loop
        # Either a dict of keyword arguments or a sequence of positional
        # arguments, forwarded to ``socks.socksocket.set_proxy`` in
        # ``connect``. A falsy value means "connect directly".
        self._proxy = proxy
        self._reader = None
        self._writer = None
        # Set while there is no live connection; cleared by ``connect``.
        self._disconnected = asyncio.Event(loop=loop)
        self._disconnected.set()
        self._disconnected_future = None
        self._send_task = None
        self._recv_task = None
        # Single-slot queues: ``put`` waits while the slot is occupied.
        self._send_queue = asyncio.Queue(1)
        self._recv_queue = asyncio.Queue(1)

    async def connect(self, timeout=None, ssl=None):
        """
        Establishes a connection with the server.

        :param timeout: seconds to wait for the connection to be
            established (``None`` waits forever).
        :param ssl: without a proxy, passed straight to
            ``asyncio.open_connection``; with a proxy it is used as a flag
            that wraps the proxied socket in TLS.

        On success the background send and receive loops are started.
        """
        if not self._proxy:
            self._reader, self._writer = await asyncio.wait_for(
                asyncio.open_connection(
                    self._ip, self._port, loop=self._loop, ssl=ssl),
                loop=self._loop, timeout=timeout
            )
        else:
            import socks
            if ':' in self._ip:
                mode, address = socket.AF_INET6, (self._ip, self._port, 0, 0)
            else:
                mode, address = socket.AF_INET, (self._ip, self._port)

            s = socks.socksocket(mode, socket.SOCK_STREAM)
            try:
                if isinstance(self._proxy, dict):
                    s.set_proxy(**self._proxy)
                else:
                    s.set_proxy(*self._proxy)

                s.setblocking(False)
                await asyncio.wait_for(
                    self._loop.sock_connect(s, address),
                    timeout=timeout,
                    loop=self._loop
                )
                if ssl:
                    # The handshake happens inside ``wrap_socket``
                    # (do_handshake_on_connect=True), so bound it by the
                    # same timeout. ``s`` must be rebound: the *wrapped*
                    # socket is the one handed to asyncio below, otherwise
                    # the TLS layer would be silently dropped.
                    s.settimeout(timeout)
                    s = ssl_mod.wrap_socket(
                        s,
                        do_handshake_on_connect=True,
                        ssl_version=ssl_mod.PROTOCOL_SSLv23,
                        ciphers='ADH-AES256-SHA'
                    )
                    s.setblocking(False)

                self._reader, self._writer = \
                    await asyncio.open_connection(sock=s, loop=self._loop)
            except BaseException:
                # Don't leak the socket when connecting, the handshake or
                # the stream setup fails (cancellation included).
                s.close()
                raise

        self._disconnected.clear()
        self._disconnected_future = None
        self._send_task = self._loop.create_task(self._send_loop())
        self._recv_task = self._loop.create_task(self._recv_loop())

    def disconnect(self):
        """
        Disconnects from the server.

        Marks the connection as disconnected, cancels the background
        send/receive tasks (if started) and closes the writer (if any).
        """
        self._disconnected.set()
        if self._send_task:
            self._send_task.cancel()

        if self._recv_task:
            self._recv_task.cancel()

        if self._writer:
            self._writer.close()

    @property
    def disconnected(self):
        """
        A task that completes once the connection is marked as disconnected.

        The task is created lazily and reused until the next ``connect``.
        """
        if not self._disconnected_future:
            self._disconnected_future = \
                self._loop.create_task(self._disconnected.wait())
        return self._disconnected_future

    def clone(self):
        """
        Creates a clone of the connection.

        The clone is a new, not yet connected instance of the same class
        targeting the same address, with the same loop and proxy.
        """
        kwargs = {'loop': self._loop}
        if self._proxy:
            # Only forwarded when set, so direct connections construct the
            # clone exactly as before.
            kwargs['proxy'] = self._proxy
        return self.__class__(self._ip, self._port, **kwargs)

    def send(self, data):
        """
        Sends a packet of data through this connection mode.

        This method returns a coroutine.
        """
        return self._send_queue.put(data)

    async def recv(self):
        """
        Receives a packet of data through this connection mode.

        This method returns a coroutine. If the receive loop failed, the
        exception it hit is re-raised here instead.
        """
        ok, result = await self._recv_queue.get()
        if ok:
            return result
        else:
            raise result from None

    async def _send_loop(self):
        """
        This loop is constantly popping items off the queue to send them.

        An unexpected error is logged and disconnects the connection.
        """
        try:
            while not self._disconnected.is_set():
                self._send(await self._send_queue.get())
                await self._writer.drain()
        except asyncio.CancelledError:
            pass
        except Exception:
            # Use this module's logger rather than the root logger.
            __log__.exception('Unhandled exception in the sending loop')
            self.disconnect()

    async def _recv_loop(self):
        """
        This loop is constantly putting items on the queue as they're read.

        Items are ``(True, data)`` pairs. An error is queued as
        ``(False, exception)`` so that ``recv`` can re-raise it, and the
        connection is then disconnected.
        """
        try:
            while not self._disconnected.is_set():
                data = await self._recv()
                await self._recv_queue.put((True, data))
        except asyncio.CancelledError:
            pass
        except Exception as e:
            await self._recv_queue.put((False, e))
            self.disconnect()

    @abc.abstractmethod
    def _send(self, data):
        """
        This method should be implemented differently under each
        connection mode and serialize the data into the packet
        the way it should be sent through `self._writer`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def _recv(self):
        """
        This method should be implemented differently under each
        connection mode and deserialize the data from the packet
        the way it should be read from `self._reader`.
        """
        raise NotImplementedError

    def __str__(self):
        return '{}:{}/{}'.format(
            self._ip, self._port,
            self.__class__.__name__.replace('Connection', '')
        )