Support custom connectors

This commit is contained in:
Lonami Exo 2023-10-29 21:45:21 +01:00
parent d80c6b3bb4
commit 6e88264b28
8 changed files with 161 additions and 51 deletions

View File

@ -107,3 +107,11 @@ Private definitions
.. autoclass:: InFileLike .. autoclass:: InFileLike
.. autoclass:: OutFileLike .. autoclass:: OutFileLike
.. currentmodule:: telethon._impl.mtsender.sender
.. autoclass:: AsyncReader
.. autoclass:: AsyncWriter
.. autoclass:: Connector

View File

@ -49,7 +49,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
try: try:
await client._storage.save(client._session) await client._storage.save(client._session)
except Exception: except Exception:
client._logger.exception( client._config.base_logger.exception(
"failed to save session upon login; you may need to login again in future runs" "failed to save session upon login; you may need to login again in future runs"
) )
@ -59,7 +59,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None assert dc_id is not None
sender, client._session.dcs = await connect_sender( sender, client._session.dcs = await connect_sender(
client._config, client._session.dcs, DataCenter(id=dc_id), client._logger client._config, client._session.dcs, DataCenter(id=dc_id)
) )
async with client._sender_lock: async with client._sender_lock:
client._sender = sender client._sender = sender

View File

@ -19,12 +19,11 @@ from typing import (
Union, Union,
) )
from telethon._impl.session.session import DataCenter
from ....version import __version__ as default_version from ....version import __version__ as default_version
from ...mtsender import Sender from ...mtsender import Connector, Sender
from ...session import ( from ...session import (
ChatHashCache, ChatHashCache,
DataCenter,
MemorySession, MemorySession,
MessageBox, MessageBox,
PackedChat, PackedChat,
@ -193,6 +192,13 @@ class Client:
:param datacenter: :param datacenter:
Override the datacenter to connect to. Override the datacenter to connect to.
Useful to connect to one of Telegram's test servers. Useful to connect to one of Telegram's test servers.
:param connector:
Asynchronous function called to connect to a remote address.
By default, this is :func:`asyncio.open_connection`.
In order to use proxies, you can set a custom connector.
See :class:`~telethon._impl.mtsender.sender.Connector` for more details.
""" """
def __init__( def __init__(
@ -212,10 +218,9 @@ class Client:
system_lang_code: Optional[str] = None, system_lang_code: Optional[str] = None,
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
datacenter: Optional[DataCenter] = None, datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None,
) -> None: ) -> None:
self._logger = logger or logging.getLogger( base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
__package__[: __package__.index(".")]
)
self._sender: Optional[Sender] = None self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock() self._sender_lock = asyncio.Lock()
@ -240,11 +245,13 @@ class Client:
if flood_sleep_threshold is None if flood_sleep_threshold is None
else flood_sleep_threshold, else flood_sleep_threshold,
update_queue_limit=update_queue_limit, update_queue_limit=update_queue_limit,
base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
) )
self._session = Session() self._session = Session()
self._message_box = MessageBox(base_logger=self._logger) self._message_box = MessageBox(base_logger=base_logger)
self._chat_hashes = ChatHashCache(None) self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[ self._updates: asyncio.Queue[

View File

@ -10,9 +10,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar
from ....version import __version__ from ....version import __version__
from ...mtproto import BadStatus, Full, RpcError from ...mtproto import BadStatus, Full, RpcError
from ...mtsender import Sender from ...mtsender import Connector, Sender
from ...mtsender import connect as connect_without_auth from ...mtsender import connect as do_connect_sender
from ...mtsender import connect_with_auth
from ...session import DataCenter from ...session import DataCenter
from ...session import User as SessionUser from ...session import User as SessionUser
from ...tl import LAYER, Request, abcs, functions, types from ...tl import LAYER, Request, abcs, functions, types
@ -45,6 +44,8 @@ def default_system_version() -> str:
class Config: class Config:
api_id: int api_id: int
api_hash: str api_hash: str
base_logger: logging.Logger
connector: Connector
device_model: str = field(default_factory=default_device_model) device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version) system_version: str = field(default_factory=default_system_version)
app_version: str = __version__ app_version: str = __version__
@ -76,7 +77,6 @@ async def connect_sender(
config: Config, config: Config,
known_dcs: List[DataCenter], known_dcs: List[DataCenter],
dc: DataCenter, dc: DataCenter,
base_logger: logging.Logger,
force_auth_gen: bool = False, force_auth_gen: bool = False,
) -> Tuple[Sender, List[DataCenter]]: ) -> Tuple[Sender, List[DataCenter]]:
# Only the ID of the input DC may be known. # Only the ID of the input DC may be known.
@ -93,11 +93,14 @@ async def connect_sender(
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
) )
transport = Full() sender = await do_connect_sender(
if auth: Full(),
sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger) dc.id,
else: addr,
sender = await connect_without_auth(transport, dc.id, addr, base_logger) auth_key=auth,
base_logger=config.base_logger,
connector=config.connector,
)
try: try:
remote_config_data = await sender.invoke( remote_config_data = await sender.invoke(
@ -122,13 +125,11 @@ async def connect_sender(
dc = DataCenter( dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
) )
base_logger.warning( config.base_logger.warning(
"datacenter could not find stored auth; will retry generating a new one: %s", "datacenter could not find stored auth; will retry generating a new one: %s",
dc, dc,
) )
return await connect_sender( return await connect_sender(config, known_dcs, dc, force_auth_gen=True)
config, known_dcs, dc, base_logger, force_auth_gen=True
)
else: else:
raise raise
@ -177,7 +178,7 @@ async def connect(self: Client) -> None:
id=self._session.user.dc if self._session.user else DEFAULT_DC id=self._session.user.dc if self._session.user else DEFAULT_DC
) )
self._sender, self._session.dcs = await connect_sender( self._sender, self._session.dcs = await connect_sender(
self._config, self._session.dcs, datacenter, self._logger self._config, self._session.dcs, datacenter
) )
if self._message_box.is_empty() and self._session.user: if self._message_box.is_empty() and self._session.user:
@ -216,7 +217,7 @@ async def disconnect(self: Client) -> None:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception: except Exception:
self._logger.exception( self._config.base_logger.exception(
"unhandled exception when cancelling dispatcher; this is a bug" "unhandled exception when cancelling dispatcher; this is a bug"
) )
finally: finally:
@ -225,7 +226,9 @@ async def disconnect(self: Client) -> None:
try: try:
await sender.disconnect() await sender.disconnect()
except Exception: except Exception:
self._logger.exception("unhandled exception during disconnect; this is a bug") self._config.base_logger.exception(
"unhandled exception during disconnect; this is a bug"
)
self._session.state = self._message_box.session_state() self._session.state = self._message_box.session_state()
await self._storage.save(self._session) await self._storage.save(self._session)

View File

@ -116,7 +116,7 @@ def extend_update_queue(
now - client._last_update_limit_warn now - client._last_update_limit_warn
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
): ):
client._logger.warning( client._config.base_logger.warning(
"updates are being dropped because limit=%d has been reached", "updates are being dropped because limit=%d has been reached",
client._updates.maxsize, client._updates.maxsize,
) )
@ -134,13 +134,13 @@ async def dispatcher(client: Client) -> None:
except Exception as e: except Exception as e:
if isinstance(e, RuntimeError) and loop.is_closed(): if isinstance(e, RuntimeError) and loop.is_closed():
# User probably forgot to call disconnect. # User probably forgot to call disconnect.
client._logger.warning( client._config.base_logger.warning(
"client was not closed cleanly, make sure to call client.disconnect()! %s", "client was not closed cleanly, make sure to call client.disconnect()! %s",
e, e,
) )
return return
else: else:
client._logger.exception( client._config.base_logger.exception(
"unhandled exception in event handler; this is probably a bug in your code, not telethon" "unhandled exception in event handler; this is probably a bug in your code, not telethon"
) )
raise raise

View File

@ -2,16 +2,20 @@ from .sender import (
MAXIMUM_DATA, MAXIMUM_DATA,
NO_PING_DISCONNECT, NO_PING_DISCONNECT,
PING_DELAY, PING_DELAY,
AsyncReader,
AsyncWriter,
Connector,
Sender, Sender,
connect, connect,
connect_with_auth,
) )
__all__ = [ __all__ = [
"MAXIMUM_DATA", "MAXIMUM_DATA",
"NO_PING_DISCONNECT", "NO_PING_DISCONNECT",
"PING_DELAY", "PING_DELAY",
"AsyncReader",
"AsyncWriter",
"Connector",
"Sender", "Sender",
"connect", "connect",
"connect_with_auth",
] ]

View File

@ -3,9 +3,9 @@ import logging
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter from asyncio import FIRST_COMPLETED, Event, Future
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, List, Optional, Self, TypeVar from typing import Generic, List, Optional, Protocol, Self, Tuple, TypeVar
from ..crypto import AuthKey from ..crypto import AuthKey
from ..mtproto import ( from ..mtproto import (
@ -43,6 +43,74 @@ def generate_random_id() -> int:
return _last_id return _last_id
class AsyncReader(Protocol):
"""
A :class:`asyncio.StreamReader`-like class.
"""
async def read(self, n: int) -> bytes:
"""
Must behave like :meth:`asyncio.StreamReader.read`.
:param n:
Amount of bytes to read at most.
"""
class AsyncWriter(Protocol):
"""
A :class:`asyncio.StreamWriter`-like class.
"""
def write(self, data: bytes) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.write`.
:param data:
Data that must be entirely written or buffered until :meth:`drain` is called.
"""
async def drain(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.drain`.
"""
def close(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.close`.
"""
async def wait_closed(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.wait_closed`.
"""
class Connector(Protocol):
"""
A *Connector* is any function that takes in the following two positional parameters as input:
* The ``ip`` address as a :class:`str`. This might be either a IPv4 or IPv6.
* The ``port`` as a :class:`int`. This will be a number below 2¹, often 443.
and returns a :class:`tuple`\ [:class:`AsyncReader`, :class:`AsyncWriter`].
You can use a custom connector to connect to Telegram through proxies.
The library will only ever open remote connections through this function.
The default connector is :func:`asyncio.open_connection`, defined as:
.. code-block:: python
default_connector = lambda ip, port: asyncio.open_connection(ip, port)
If your connector needs additional parameters, you can use either the :keyword:`lambda` syntax or :func:`functools.partial`.
"""
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
pass
class RequestState(ABC): class RequestState(ABC):
pass pass
@ -80,8 +148,8 @@ class Sender:
dc_id: int dc_id: int
addr: str addr: str
_logger: logging.Logger _logger: logging.Logger
_reader: StreamReader _reader: AsyncReader
_writer: StreamWriter _writer: AsyncWriter
_transport: Transport _transport: Transport
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray _mtp_buffer: bytearray
@ -98,9 +166,12 @@ class Sender:
mtp: Mtp, mtp: Mtp,
dc_id: int, dc_id: int,
addr: str, addr: str,
*,
connector: Connector,
base_logger: logging.Logger, base_logger: logging.Logger,
) -> Self: ) -> Self:
reader, writer = await asyncio.open_connection(*addr.split(":")) ip, port = addr.split(":")
reader, writer = await connector(ip, int(port))
return cls( return cls(
dc_id=dc_id, dc_id=dc_id,
@ -299,10 +370,33 @@ class Sender:
async def connect( async def connect(
transport: Transport, dc_id: int, addr: str, base_logger: logging.Logger transport: Transport,
dc_id: int,
addr: str,
*,
auth_key: Optional[bytes],
base_logger: logging.Logger,
connector: Connector,
) -> Sender: ) -> Sender:
sender = await Sender.connect(transport, Plain(), dc_id, addr, base_logger) if auth_key is None:
sender = await Sender.connect(
transport,
Plain(),
dc_id,
addr,
connector=connector,
base_logger=base_logger,
)
return await generate_auth_key(sender) return await generate_auth_key(sender)
else:
return await Sender.connect(
transport,
Encrypted(AuthKey.from_bytes(auth_key)),
dc_id,
addr,
connector=connector,
base_logger=base_logger,
)
async def generate_auth_key(sender: Sender) -> Sender: async def generate_auth_key(sender: Sender) -> Sender:
@ -320,15 +414,3 @@ async def generate_auth_key(sender: Sender) -> Sender:
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
return sender return sender
async def connect_with_auth(
transport: Transport,
dc_id: int,
addr: str,
auth_key: bytes,
base_logger: logging.Logger,
) -> Sender:
return await Sender.connect(
transport, Encrypted(AuthKey.from_bytes(auth_key)), dc_id, addr, base_logger
)

View File

@ -21,7 +21,13 @@ async def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None:
return deadline - asyncio.get_running_loop().time() return deadline - asyncio.get_running_loop().time()
sender = await asyncio.wait_for( sender = await asyncio.wait_for(
connect(Full(), *TELEGRAM_TEST_DC, logging.getLogger(__file__)), connect(
Full(),
*TELEGRAM_TEST_DC,
auth_key=None,
base_logger=logging.getLogger(__file__),
connector=lambda ip, port: asyncio.open_connection(ip, port),
),
timeout(), timeout(),
) )