add local_addr binding

This commit is contained in:
Stefan Staudinger 2020-10-06 17:22:04 +02:00
parent 05af5d0d74
commit 601830d219
2 changed files with 28 additions and 5 deletions

View File

@ -215,6 +215,7 @@ class TelegramBaseClient(abc.ABC):
connection: 'typing.Type[Connection]' = ConnectionTcpFull,
use_ipv6: bool = False,
proxy: typing.Union[tuple, dict] = None,
local_addr=None,
timeout: int = 10,
request_retries: int = 5,
connection_retries: int =5,
@ -313,12 +314,23 @@ class TelegramBaseClient(abc.ABC):
)
)
if local_addr is not None:
if use_ipv6 is False and ':' in local_addr:
raise TypeError(
'A local IPv6 address must only be used with `use_ipv6=True`.'
)
elif use_ipv6 is True and ':' not in local_addr:
raise TypeError(
'`use_ipv6=True` must only be used with a local IPv6 address.'
)
self._raise_last_call_error = raise_last_call_error
self._request_retries = request_retries
self._connection_retries = connection_retries
self._retry_delay = retry_delay or 0
self._proxy = proxy
self._local_addr = local_addr
self._timeout = timeout
self._auto_reconnect = auto_reconnect
@ -501,7 +513,8 @@ class TelegramBaseClient(abc.ABC):
self.session.port,
self.session.dc_id,
loggers=self._log,
proxy=self._proxy
proxy=self._proxy,
local_addr=self._local_addr
)):
# We don't want to init or modify anything if we were already connected
return
@ -663,7 +676,8 @@ class TelegramBaseClient(abc.ABC):
dc.port,
dc.id,
loggers=self._log,
proxy=self._proxy
proxy=self._proxy,
local_addr=self._local_addr
))
self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc)
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
@ -698,7 +712,8 @@ class TelegramBaseClient(abc.ABC):
dc.port,
dc.id,
loggers=self._log,
proxy=self._proxy
proxy=self._proxy,
local_addr=self._local_addr
))
state.add_borrow()

View File

@ -28,12 +28,13 @@ class Connection(abc.ABC):
# should be one of `PacketCodec` implementations
packet_codec = None
def __init__(self, ip, port, dc_id, *, loggers, proxy=None):
def __init__(self, ip, port, dc_id, *, loggers, proxy=None, local_addr=None):
self._ip = ip
self._port = port
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
self._log = loggers[__name__]
self._proxy = proxy
self._local_addr = local_addr
self._reader = None
self._writer = None
self._connected = False
@ -46,8 +47,13 @@ class Connection(abc.ABC):
async def _connect(self, timeout=None, ssl=None):
if not self._proxy:
if self._local_addr is not None:
local_addr = (self._local_addr, None)
else:
local_addr = None
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(self._ip, self._port, ssl=ssl),
asyncio.open_connection(self._ip, self._port, ssl=ssl, local_addr=local_addr),
timeout=timeout
)
else:
@ -64,6 +70,8 @@ class Connection(abc.ABC):
s.set_proxy(*self._proxy)
s.settimeout(timeout)
if self._local_addr is not None:
s.bind((self._local_addr, None))
await asyncio.wait_for(
asyncio.get_event_loop().sock_connect(s, address),
timeout=timeout