Add support for proxy again

This commit is contained in:
Lonami Exo 2018-10-04 17:11:31 +02:00
parent db83709c6b
commit ebde3be820
7 changed files with 65 additions and 35 deletions

View File

@ -202,6 +202,7 @@ class TelegramBaseClient(abc.ABC):
self._request_retries = request_retries or sys.maxsize
self._connection_retries = connection_retries or sys.maxsize
self._proxy = proxy
self._timeout = timeout
self._auto_reconnect = auto_reconnect
@ -307,7 +308,9 @@ class TelegramBaseClient(abc.ABC):
Connects to Telegram.
"""
await self._sender.connect(self.session.auth_key, self._connection(
self.session.server_address, self.session.port, loop=self._loop))
self.session.server_address, self.session.port,
loop=self._loop, proxy=self._proxy
))
await self._sender.send(self._init_with(
functions.help.GetConfigRequest()))
@ -419,7 +422,7 @@ class TelegramBaseClient(abc.ABC):
# with no further clues.
sender = MTProtoSender(self._loop)
await sender.connect(None, self._connection(
dc.ip_address, dc.port, loop=self._loop))
dc.ip_address, dc.port, loop=self._loop, proxy=self._proxy))
__log__.info('Exporting authorization for data center %s', dc)
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
req = self._init_with(functions.auth.ImportAuthorizationRequest(

View File

@ -1,7 +1,8 @@
import abc
import asyncio
import logging
import socket
import ssl as ssl_mod
__log__ = logging.getLogger(__name__)
@ -16,11 +17,11 @@ class Connection(abc.ABC):
under any conditions for as long as the user doesn't disconnect or
the input parameters to auto-reconnect dictate otherwise.
"""
# TODO Support proxy
def __init__(self, ip, port, *, loop):
def __init__(self, ip, port, *, loop, proxy=None):
self._ip = ip
self._port = port
self._loop = loop
self._proxy = proxy
self._reader = None
self._writer = None
self._disconnected = asyncio.Event(loop=loop)
@ -31,14 +32,48 @@ class Connection(abc.ABC):
self._send_queue = asyncio.Queue(1)
self._recv_queue = asyncio.Queue(1)
async def connect(self, timeout=None):
async def connect(self, timeout=None, ssl=None):
"""
Establishes a connection with the server.
"""
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(self._ip, self._port, loop=self._loop),
loop=self._loop, timeout=timeout
)
if not self._proxy:
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(
self._ip, self._port, loop=self._loop, ssl=ssl),
loop=self._loop, timeout=timeout
)
else:
import socks
if ':' in self._ip:
mode, address = socket.AF_INET6, (self._ip, self._port, 0, 0)
else:
mode, address = socket.AF_INET, (self._ip, self._port)
s = socks.socksocket(mode, socket.SOCK_STREAM)
if isinstance(self._proxy, dict):
s.set_proxy(**self._proxy)
else:
s.set_proxy(*self._proxy)
s.setblocking(False)
await asyncio.wait_for(
self._loop.sock_connect(s, address),
timeout=timeout,
loop=self._loop
)
if ssl:
self._socket.settimeout(timeout)
self._socket = ssl_mod.wrap_socket(
s,
do_handshake_on_connect=True,
ssl_version=ssl_mod.PROTOCOL_SSLv23,
ciphers='ADH-AES256-SHA'
)
self._socket.setblocking(False)
self._reader, self._writer = await asyncio.open_connection(
self._ip, self._port, loop=self._loop, sock=s
)
self._disconnected.clear()
self._disconnected_future = None

View File

@ -3,20 +3,12 @@ import asyncio
from .connection import Connection
class ConnectionHttp(Connection):
async def connect(self, timeout=None):
# TODO Test if the ssl part works or it needs to be as before:
# dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(
self._ip, self._port, loop=self._loop, ssl=True),
loop=self._loop, timeout=timeout
)
SSL_PORT = 443
self._disconnected.clear()
self._disconnected_future = None
self._send_task = self._loop.create_task(self._send_loop())
self._recv_task = self._loop.create_task(self._send_loop())
class ConnectionHttp(Connection):
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)
def _send(self, message):
self._writer.write(

View File

@ -9,8 +9,8 @@ class ConnectionTcpAbridged(Connection):
only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common).
"""
async def connect(self, timeout=None):
await super().connect(timeout=timeout)
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
await self.send(b'\xef')
def _write(self, data):

View File

@ -10,12 +10,12 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself.
"""
def __init__(self, ip, port, *, loop):
super().__init__(ip, port, loop=loop)
def __init__(self, ip, port, *, loop, proxy=None):
super().__init__(ip, port, loop=loop, proxy=proxy)
self._send_counter = 0
async def connect(self, timeout=None):
await super().connect(timeout=timeout)
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
self._send_counter = 0 # Important or Telegram won't reply
def _send(self, data):

View File

@ -8,8 +8,8 @@ class ConnectionTcpIntermediate(Connection):
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length.
"""
async def connect(self, timeout=None):
await super().connect(timeout=timeout)
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
await self.send(b'\xee\xee\xee\xee')
def _send(self, data):

View File

@ -11,8 +11,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern.
"""
def __init__(self, ip, port, *, loop):
super().__init__(ip, port, loop=loop)
def __init__(self, ip, port, *, loop, proxy=None):
super().__init__(ip, port, loop=loop, proxy=proxy)
self._aes_encrypt = None
self._aes_decrypt = None
@ -22,8 +22,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
async def _read(self, n):
return self._aes_decrypt.encrypt(await self._reader.readexactly(n))
async def connect(self, timeout=None):
await Connection.connect(self, timeout=timeout)
async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout, ssl=ssl)
# Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee')