From d2756cf68fb8b87ea631e2aa11d548a5dc58e077 Mon Sep 17 00:00:00 2001 From: Stefan Date: Wed, 7 Oct 2020 10:03:19 +0200 Subject: [PATCH] Add support for local_ip address binding (#1587) --- telethon/client/telegrambaseclient.py | 26 ++++++++++++++++++++--- telethon/network/connection/connection.py | 12 +++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/telethon/client/telegrambaseclient.py b/telethon/client/telegrambaseclient.py index 444a4181..43aae812 100644 --- a/telethon/client/telegrambaseclient.py +++ b/telethon/client/telegrambaseclient.py @@ -110,6 +110,11 @@ class TelegramBaseClient(abc.ABC): function parameters for PySocks, like ``(type, 'hostname', port)``. See https://github.com/Anorov/PySocks#usage-1 for more. + local_addr (`str`, optional): + Local host address used to bind the socket to locally. + You only need to use this if you have multiple network cards and + want to use a specific one. + timeout (`int` | `float`, optional): The timeout in seconds to be used when connecting. This is **not** the timeout to be used when ``await``'ing for @@ -215,6 +220,7 @@ class TelegramBaseClient(abc.ABC): connection: 'typing.Type[Connection]' = ConnectionTcpFull, use_ipv6: bool = False, proxy: typing.Union[tuple, dict] = None, + local_addr=None, timeout: int = 10, request_retries: int = 5, connection_retries: int =5, @@ -313,12 +319,23 @@ class TelegramBaseClient(abc.ABC): ) ) + if local_addr is not None: + if use_ipv6 is False and ':' in local_addr: + raise TypeError( + 'A local IPv6 address must only be used with `use_ipv6=True`.' + ) + elif use_ipv6 is True and ':' not in local_addr: + raise TypeError( + '`use_ipv6=True` must only be used with a local IPv6 address.' + ) + self._raise_last_call_error = raise_last_call_error self._request_retries = request_retries self._connection_retries = connection_retries self._retry_delay = retry_delay or 0 self._proxy = proxy + self._local_addr = local_addr self._timeout = timeout self._auto_reconnect = auto_reconnect @@ -501,7 +518,8 @@ class TelegramBaseClient(abc.ABC): self.session.port, self.session.dc_id, loggers=self._log, - proxy=self._proxy + proxy=self._proxy, + local_addr=self._local_addr )): # We don't want to init or modify anything if we were already connected return @@ -663,7 +681,8 @@ class TelegramBaseClient(abc.ABC): dc.port, dc.id, loggers=self._log, - proxy=self._proxy + proxy=self._proxy, + local_addr=self._local_addr )) self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc) auth = await self(functions.auth.ExportAuthorizationRequest(dc_id)) @@ -698,7 +717,8 @@ class TelegramBaseClient(abc.ABC): dc.port, dc.id, loggers=self._log, - proxy=self._proxy + proxy=self._proxy, + local_addr=self._local_addr )) state.add_borrow() diff --git a/telethon/network/connection/connection.py b/telethon/network/connection/connection.py index 800ff02b..e315a939 100644 --- a/telethon/network/connection/connection.py +++ b/telethon/network/connection/connection.py @@ -28,12 +28,13 @@ class Connection(abc.ABC): # should be one of `PacketCodec` implementations packet_codec = None - def __init__(self, ip, port, dc_id, *, loggers, proxy=None): + def __init__(self, ip, port, dc_id, *, loggers, proxy=None, local_addr=None): self._ip = ip self._port = port self._dc_id = dc_id # only for MTProxy, it's an abstraction leak self._log = loggers[__name__] self._proxy = proxy + self._local_addr = local_addr self._reader = None self._writer = None self._connected = False @@ -46,8 +47,13 @@ class Connection(abc.ABC): async def _connect(self, timeout=None, ssl=None): if not self._proxy: + if self._local_addr is not None: + local_addr = (self._local_addr, None) + else: + local_addr = None + self._reader, self._writer = await asyncio.wait_for( - asyncio.open_connection(self._ip, self._port, ssl=ssl), + asyncio.open_connection(self._ip, self._port, ssl=ssl, local_addr=local_addr), timeout=timeout ) else: @@ -64,6 +70,8 @@ class Connection(abc.ABC): s.set_proxy(*self._proxy) s.settimeout(timeout) + if self._local_addr is not None: + s.bind((self._local_addr, None)) await asyncio.wait_for( asyncio.get_event_loop().sock_connect(s, address), timeout=timeout