Add support for local_ip address binding (#1587)

This commit is contained in:
Stefan 2020-10-07 10:03:19 +02:00 committed by GitHub
parent ce71b3293b
commit d2756cf68f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 5 deletions

View File

@ -110,6 +110,11 @@ class TelegramBaseClient(abc.ABC):
function parameters for PySocks, like ``(type, 'hostname', port)``. function parameters for PySocks, like ``(type, 'hostname', port)``.
See https://github.com/Anorov/PySocks#usage-1 for more. See https://github.com/Anorov/PySocks#usage-1 for more.
local_addr (`str`, optional):
Local host address used to bind the socket to locally.
You only need to use this if you have multiple network cards and
want to use a specific one.
timeout (`int` | `float`, optional): timeout (`int` | `float`, optional):
The timeout in seconds to be used when connecting. The timeout in seconds to be used when connecting.
This is **not** the timeout to be used when ``await``'ing for This is **not** the timeout to be used when ``await``'ing for
@ -215,6 +220,7 @@ class TelegramBaseClient(abc.ABC):
connection: 'typing.Type[Connection]' = ConnectionTcpFull, connection: 'typing.Type[Connection]' = ConnectionTcpFull,
use_ipv6: bool = False, use_ipv6: bool = False,
proxy: typing.Union[tuple, dict] = None, proxy: typing.Union[tuple, dict] = None,
local_addr=None,
timeout: int = 10, timeout: int = 10,
request_retries: int = 5, request_retries: int = 5,
connection_retries: int =5, connection_retries: int =5,
@ -313,12 +319,23 @@ class TelegramBaseClient(abc.ABC):
) )
) )
if local_addr is not None:
if use_ipv6 is False and ':' in local_addr:
raise TypeError(
'A local IPv6 address must only be used with `use_ipv6=True`.'
)
elif use_ipv6 is True and ':' not in local_addr:
raise TypeError(
'`use_ipv6=True` must only be used with a local IPv6 address.'
)
self._raise_last_call_error = raise_last_call_error self._raise_last_call_error = raise_last_call_error
self._request_retries = request_retries self._request_retries = request_retries
self._connection_retries = connection_retries self._connection_retries = connection_retries
self._retry_delay = retry_delay or 0 self._retry_delay = retry_delay or 0
self._proxy = proxy self._proxy = proxy
self._local_addr = local_addr
self._timeout = timeout self._timeout = timeout
self._auto_reconnect = auto_reconnect self._auto_reconnect = auto_reconnect
@ -501,7 +518,8 @@ class TelegramBaseClient(abc.ABC):
self.session.port, self.session.port,
self.session.dc_id, self.session.dc_id,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy,
local_addr=self._local_addr
)): )):
# We don't want to init or modify anything if we were already connected # We don't want to init or modify anything if we were already connected
return return
@ -663,7 +681,8 @@ class TelegramBaseClient(abc.ABC):
dc.port, dc.port,
dc.id, dc.id,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy,
local_addr=self._local_addr
)) ))
self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc) self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc)
auth = await self(functions.auth.ExportAuthorizationRequest(dc_id)) auth = await self(functions.auth.ExportAuthorizationRequest(dc_id))
@ -698,7 +717,8 @@ class TelegramBaseClient(abc.ABC):
dc.port, dc.port,
dc.id, dc.id,
loggers=self._log, loggers=self._log,
proxy=self._proxy proxy=self._proxy,
local_addr=self._local_addr
)) ))
state.add_borrow() state.add_borrow()

View File

@ -28,12 +28,13 @@ class Connection(abc.ABC):
# should be one of `PacketCodec` implementations # should be one of `PacketCodec` implementations
packet_codec = None packet_codec = None
def __init__(self, ip, port, dc_id, *, loggers, proxy=None): def __init__(self, ip, port, dc_id, *, loggers, proxy=None, local_addr=None):
self._ip = ip self._ip = ip
self._port = port self._port = port
self._dc_id = dc_id # only for MTProxy, it's an abstraction leak self._dc_id = dc_id # only for MTProxy, it's an abstraction leak
self._log = loggers[__name__] self._log = loggers[__name__]
self._proxy = proxy self._proxy = proxy
self._local_addr = local_addr
self._reader = None self._reader = None
self._writer = None self._writer = None
self._connected = False self._connected = False
@ -46,8 +47,13 @@ class Connection(abc.ABC):
async def _connect(self, timeout=None, ssl=None): async def _connect(self, timeout=None, ssl=None):
if not self._proxy: if not self._proxy:
if self._local_addr is not None:
local_addr = (self._local_addr, None)
else:
local_addr = None
self._reader, self._writer = await asyncio.wait_for( self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(self._ip, self._port, ssl=ssl), asyncio.open_connection(self._ip, self._port, ssl=ssl, local_addr=local_addr),
timeout=timeout timeout=timeout
) )
else: else:
@ -64,6 +70,8 @@ class Connection(abc.ABC):
s.set_proxy(*self._proxy) s.set_proxy(*self._proxy)
s.settimeout(timeout) s.settimeout(timeout)
if self._local_addr is not None:
s.bind((self._local_addr, None))
await asyncio.wait_for( await asyncio.wait_for(
asyncio.get_event_loop().sock_connect(s, address), asyncio.get_event_loop().sock_connect(s, address),
timeout=timeout timeout=timeout