Add support for proxy again

This commit is contained in:
Lonami Exo 2018-10-04 17:11:31 +02:00
parent db83709c6b
commit ebde3be820
7 changed files with 65 additions and 35 deletions

View File

@ -202,6 +202,7 @@ class TelegramBaseClient(abc.ABC):
self._request_retries = request_retries or sys.maxsize self._request_retries = request_retries or sys.maxsize
self._connection_retries = connection_retries or sys.maxsize self._connection_retries = connection_retries or sys.maxsize
self._proxy = proxy
self._timeout = timeout self._timeout = timeout
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
@ -307,7 +308,9 @@ class TelegramBaseClient(abc.ABC):
Connects to Telegram. Connects to Telegram.
""" """
await self._sender.connect(self.session.auth_key, self._connection( await self._sender.connect(self.session.auth_key, self._connection(
self.session.server_address, self.session.port, loop=self._loop)) self.session.server_address, self.session.port,
loop=self._loop, proxy=self._proxy
))
await self._sender.send(self._init_with( await self._sender.send(self._init_with(
functions.help.GetConfigRequest())) functions.help.GetConfigRequest()))
@ -419,7 +422,7 @@ class TelegramBaseClient(abc.ABC):
# with no further clues. # with no further clues.
sender = MTProtoSender(self._loop) sender = MTProtoSender(self._loop)
await sender.connect(None, self._connection( await sender.connect(None, self._connection(
dc.ip_address, dc.port, loop=self._loop)) dc.ip_address, dc.port, loop=self._loop, proxy=self._proxy))
__log__.info('Exporting authorization for data center %s', dc) __log__.info('Exporting authorization for data center %s', dc)
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id)) auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
req = self._init_with(functions.auth.ImportAuthorizationRequest( req = self._init_with(functions.auth.ImportAuthorizationRequest(

View File

@ -1,7 +1,8 @@
import abc import abc
import asyncio import asyncio
import logging import logging
import socket
import ssl as ssl_mod
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -16,11 +17,11 @@ class Connection(abc.ABC):
under any conditions for as long as the user doesn't disconnect or under any conditions for as long as the user doesn't disconnect or
the input parameters to auto-reconnect dictate otherwise. the input parameters to auto-reconnect dictate otherwise.
""" """
# TODO Support proxy def __init__(self, ip, port, *, loop, proxy=None):
def __init__(self, ip, port, *, loop):
self._ip = ip self._ip = ip
self._port = port self._port = port
self._loop = loop self._loop = loop
self._proxy = proxy
self._reader = None self._reader = None
self._writer = None self._writer = None
self._disconnected = asyncio.Event(loop=loop) self._disconnected = asyncio.Event(loop=loop)
@ -31,14 +32,48 @@ class Connection(abc.ABC):
self._send_queue = asyncio.Queue(1) self._send_queue = asyncio.Queue(1)
self._recv_queue = asyncio.Queue(1) self._recv_queue = asyncio.Queue(1)
async def connect(self, timeout=None): async def connect(self, timeout=None, ssl=None):
""" """
Establishes a connection with the server. Establishes a connection with the server.
""" """
self._reader, self._writer = await asyncio.wait_for( if not self._proxy:
asyncio.open_connection(self._ip, self._port, loop=self._loop), self._reader, self._writer = await asyncio.wait_for(
loop=self._loop, timeout=timeout asyncio.open_connection(
) self._ip, self._port, loop=self._loop, ssl=ssl),
loop=self._loop, timeout=timeout
)
else:
import socks
if ':' in self._ip:
mode, address = socket.AF_INET6, (self._ip, self._port, 0, 0)
else:
mode, address = socket.AF_INET, (self._ip, self._port)
s = socks.socksocket(mode, socket.SOCK_STREAM)
if isinstance(self._proxy, dict):
s.set_proxy(**self._proxy)
else:
s.set_proxy(*self._proxy)
s.setblocking(False)
await asyncio.wait_for(
self._loop.sock_connect(s, address),
timeout=timeout,
loop=self._loop
)
if ssl:
self._socket.settimeout(timeout)
self._socket = ssl_mod.wrap_socket(
s,
do_handshake_on_connect=True,
ssl_version=ssl_mod.PROTOCOL_SSLv23,
ciphers='ADH-AES256-SHA'
)
self._socket.setblocking(False)
self._reader, self._writer = await asyncio.open_connection(
self._ip, self._port, loop=self._loop, sock=s
)
self._disconnected.clear() self._disconnected.clear()
self._disconnected_future = None self._disconnected_future = None

View File

@ -3,20 +3,12 @@ import asyncio
from .connection import Connection from .connection import Connection
class ConnectionHttp(Connection): SSL_PORT = 443
async def connect(self, timeout=None):
# TODO Test if the ssl part works or it needs to be as before:
# dict(ssl_version=ssl.PROTOCOL_SSLv23, ciphers='ADH-AES256-SHA')
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(
self._ip, self._port, loop=self._loop, ssl=True),
loop=self._loop, timeout=timeout
)
self._disconnected.clear()
self._disconnected_future = None class ConnectionHttp(Connection):
self._send_task = self._loop.create_task(self._send_loop()) async def connect(self, timeout=None, ssl=None):
self._recv_task = self._loop.create_task(self._send_loop()) await super().connect(timeout=timeout, ssl=self._port == SSL_PORT)
def _send(self, message): def _send(self, message):
self._writer.write( self._writer.write(

View File

@ -9,8 +9,8 @@ class ConnectionTcpAbridged(Connection):
only require 1 byte if the packet length is less than only require 1 byte if the packet length is less than
508 bytes (127 << 2, which is very common). 508 bytes (127 << 2, which is very common).
""" """
async def connect(self, timeout=None): async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout) await super().connect(timeout=timeout, ssl=ssl)
await self.send(b'\xef') await self.send(b'\xef')
def _write(self, data): def _write(self, data):

View File

@ -10,12 +10,12 @@ class ConnectionTcpFull(Connection):
Default Telegram mode. Sends 12 additional bytes and Default Telegram mode. Sends 12 additional bytes and
needs to calculate the CRC value of the packet itself. needs to calculate the CRC value of the packet itself.
""" """
def __init__(self, ip, port, *, loop): def __init__(self, ip, port, *, loop, proxy=None):
super().__init__(ip, port, loop=loop) super().__init__(ip, port, loop=loop, proxy=proxy)
self._send_counter = 0 self._send_counter = 0
async def connect(self, timeout=None): async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout) await super().connect(timeout=timeout, ssl=ssl)
self._send_counter = 0 # Important or Telegram won't reply self._send_counter = 0 # Important or Telegram won't reply
def _send(self, data): def _send(self, data):

View File

@ -8,8 +8,8 @@ class ConnectionTcpIntermediate(Connection):
Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`. Intermediate mode between `ConnectionTcpFull` and `ConnectionTcpAbridged`.
Always sends 4 extra bytes for the packet length. Always sends 4 extra bytes for the packet length.
""" """
async def connect(self, timeout=None): async def connect(self, timeout=None, ssl=None):
await super().connect(timeout=timeout) await super().connect(timeout=timeout, ssl=ssl)
await self.send(b'\xee\xee\xee\xee') await self.send(b'\xee\xee\xee\xee')
def _send(self, data): def _send(self, data):

View File

@ -11,8 +11,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
every message with a randomly generated key using the every message with a randomly generated key using the
AES-CTR mode so the packets are harder to discern. AES-CTR mode so the packets are harder to discern.
""" """
def __init__(self, ip, port, *, loop): def __init__(self, ip, port, *, loop, proxy=None):
super().__init__(ip, port, loop=loop) super().__init__(ip, port, loop=loop, proxy=proxy)
self._aes_encrypt = None self._aes_encrypt = None
self._aes_decrypt = None self._aes_decrypt = None
@ -22,8 +22,8 @@ class ConnectionTcpObfuscated(ConnectionTcpAbridged):
async def _read(self, n): async def _read(self, n):
return self._aes_decrypt.encrypt(await self._reader.readexactly(n)) return self._aes_decrypt.encrypt(await self._reader.readexactly(n))
async def connect(self, timeout=None): async def connect(self, timeout=None, ssl=None):
await Connection.connect(self, timeout=timeout) await super().connect(timeout=timeout, ssl=ssl)
# Obfuscated messages secrets cannot start with any of these # Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee') keywords = (b'PVrG', b'GET ', b'POST', b'\xee\xee\xee\xee')