diff --git a/client/doc/modules/types.rst b/client/doc/modules/types.rst index caa97617..18ac0c3e 100644 --- a/client/doc/modules/types.rst +++ b/client/doc/modules/types.rst @@ -107,3 +107,11 @@ Private definitions .. autoclass:: InFileLike .. autoclass:: OutFileLike + +.. currentmodule:: telethon._impl.mtsender.sender + +.. autoclass:: AsyncReader + +.. autoclass:: AsyncWriter + +.. autoclass:: Connector diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 0056dbbb..59444208 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -49,7 +49,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: try: await client._storage.save(client._session) except Exception: - client._logger.exception( + client._config.base_logger.exception( "failed to save session upon login; you may need to login again in future runs" ) @@ -59,7 +59,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: assert dc_id is not None sender, client._session.dcs = await connect_sender( - client._config, client._session.dcs, DataCenter(id=dc_id), client._logger + client._config, client._session.dcs, DataCenter(id=dc_id) ) async with client._sender_lock: client._sender = sender diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index 0de0a55b..521e7181 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -19,12 +19,11 @@ from typing import ( Union, ) -from telethon._impl.session.session import DataCenter - from ....version import __version__ as default_version -from ...mtsender import Sender +from ...mtsender import Connector, Sender from ...session import ( ChatHashCache, + DataCenter, MemorySession, MessageBox, PackedChat, @@ -193,6 +192,13 @@ class Client: :param datacenter: Override the datacenter to connect to. Useful to connect to one of Telegram's test servers. + + :param connector: + Asynchronous function called to connect to a remote address. + By default, this is :func:`asyncio.open_connection`. + In order to use proxies, you can set a custom connector. + + See :class:`~telethon._impl.mtsender.sender.Connector` for more details. """ def __init__( @@ -212,10 +218,9 @@ class Client: system_lang_code: Optional[str] = None, lang_code: Optional[str] = None, datacenter: Optional[DataCenter] = None, + connector: Optional[Connector] = None, ) -> None: - self._logger = logger or logging.getLogger( - __package__[: __package__.index(".")] - ) + base_logger = logger or logging.getLogger(__package__[: __package__.index(".")]) self._sender: Optional[Sender] = None self._sender_lock = asyncio.Lock() @@ -240,11 +245,13 @@ class Client: if flood_sleep_threshold is None else flood_sleep_threshold, update_queue_limit=update_queue_limit, + base_logger=base_logger, + connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)), ) self._session = Session() - self._message_box = MessageBox(base_logger=self._logger) + self._message_box = MessageBox(base_logger=base_logger) self._chat_hashes = ChatHashCache(None) self._last_update_limit_warn: Optional[float] = None self._updates: asyncio.Queue[ diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 0eaec9d5..2ac8b9c0 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -10,9 +10,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar from ....version import __version__ from ...mtproto import BadStatus, Full, RpcError -from ...mtsender import Sender -from ...mtsender import connect as connect_without_auth -from ...mtsender import connect_with_auth +from ...mtsender import Connector, Sender +from ...mtsender import connect as do_connect_sender from ...session import DataCenter from ...session import User as SessionUser from ...tl import LAYER, Request, abcs, functions, types @@ -45,6 +44,8 @@ def default_system_version() -> str: class Config: api_id: int api_hash: str + base_logger: logging.Logger + connector: Connector device_model: str = field(default_factory=default_device_model) system_version: str = field(default_factory=default_system_version) app_version: str = __version__ @@ -76,7 +77,6 @@ async def connect_sender( config: Config, known_dcs: List[DataCenter], dc: DataCenter, - base_logger: logging.Logger, force_auth_gen: bool = False, ) -> Tuple[Sender, List[DataCenter]]: # Only the ID of the input DC may be known. @@ -93,11 +93,14 @@ async def connect_sender( or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) ) - transport = Full() - if auth: - sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger) - else: - sender = await connect_without_auth(transport, dc.id, addr, base_logger) + sender = await do_connect_sender( + Full(), + dc.id, + addr, + auth_key=auth, + base_logger=config.base_logger, + connector=config.connector, + ) try: remote_config_data = await sender.invoke( @@ -122,13 +125,11 @@ async def connect_sender( dc = DataCenter( id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None ) - base_logger.warning( + config.base_logger.warning( "datacenter could not find stored auth; will retry generating a new one: %s", dc, ) - return await connect_sender( - config, known_dcs, dc, base_logger, force_auth_gen=True - ) + return await connect_sender(config, known_dcs, dc, force_auth_gen=True) else: raise @@ -177,7 +178,7 @@ async def connect(self: Client) -> None: id=self._session.user.dc if self._session.user else DEFAULT_DC ) self._sender, self._session.dcs = await connect_sender( - self._config, self._session.dcs, datacenter, self._logger + self._config, self._session.dcs, datacenter ) if self._message_box.is_empty() and self._session.user: @@ -216,7 +217,7 @@ async def disconnect(self: Client) -> None: except asyncio.CancelledError: pass except Exception: - self._logger.exception( + self._config.base_logger.exception( "unhandled exception when cancelling dispatcher; this is a bug" ) finally: @@ -225,7 +226,9 @@ async def disconnect(self: Client) -> None: try: await sender.disconnect() except Exception: - self._logger.exception("unhandled exception during disconnect; this is a bug") + self._config.base_logger.exception( + "unhandled exception during disconnect; this is a bug" + ) self._session.state = self._message_box.session_state() await self._storage.save(self._session) diff --git a/client/src/telethon/_impl/client/client/updates.py b/client/src/telethon/_impl/client/client/updates.py index 5a94ca92..b2422327 100644 --- a/client/src/telethon/_impl/client/client/updates.py +++ b/client/src/telethon/_impl/client/client/updates.py @@ -116,7 +116,7 @@ def extend_update_queue( now - client._last_update_limit_warn > UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN ): - client._logger.warning( + client._config.base_logger.warning( "updates are being dropped because limit=%d has been reached", client._updates.maxsize, ) @@ -134,13 +134,13 @@ async def dispatcher(client: Client) -> None: except Exception as e: if isinstance(e, RuntimeError) and loop.is_closed(): # User probably forgot to call disconnect. - client._logger.warning( + client._config.base_logger.warning( "client was not closed cleanly, make sure to call client.disconnect()! %s", e, ) return else: - client._logger.exception( + client._config.base_logger.exception( "unhandled exception in event handler; this is probably a bug in your code, not telethon" ) raise diff --git a/client/src/telethon/_impl/mtsender/__init__.py b/client/src/telethon/_impl/mtsender/__init__.py index 5b2c5ed6..bc9d723f 100644 --- a/client/src/telethon/_impl/mtsender/__init__.py +++ b/client/src/telethon/_impl/mtsender/__init__.py @@ -2,16 +2,20 @@ from .sender import ( MAXIMUM_DATA, NO_PING_DISCONNECT, PING_DELAY, + AsyncReader, + AsyncWriter, + Connector, Sender, connect, - connect_with_auth, ) __all__ = [ "MAXIMUM_DATA", "NO_PING_DISCONNECT", "PING_DELAY", + "AsyncReader", + "AsyncWriter", + "Connector", "Sender", "connect", - "connect_with_auth", ] diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 52b7a8ee..c2ee7e8e 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -3,9 +3,9 @@ import logging import struct import time from abc import ABC -from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter +from asyncio import FIRST_COMPLETED, Event, Future from dataclasses import dataclass -from typing import Generic, List, Optional, Self, TypeVar +from typing import Generic, List, Optional, Protocol, Self, Tuple, TypeVar from ..crypto import AuthKey from ..mtproto import ( @@ -43,6 +43,74 @@ def generate_random_id() -> int: return _last_id +class AsyncReader(Protocol): + """ + A :class:`asyncio.StreamReader`-like class. + """ + + async def read(self, n: int) -> bytes: + """ + Must behave like :meth:`asyncio.StreamReader.read`. + + :param n: + Amount of bytes to read at most. + """ + + +class AsyncWriter(Protocol): + """ + A :class:`asyncio.StreamWriter`-like class. + """ + + def write(self, data: bytes) -> None: + """ + Must behave like :meth:`asyncio.StreamWriter.write`. + + :param data: + Data that must be entirely written or buffered until :meth:`drain` is called. + """ + + async def drain(self) -> None: + """ + Must behave like :meth:`asyncio.StreamWriter.drain`. + """ + + def close(self) -> None: + """ + Must behave like :meth:`asyncio.StreamWriter.close`. + """ + + async def wait_closed(self) -> None: + """ + Must behave like :meth:`asyncio.StreamWriter.wait_closed`. + """ + + +class Connector(Protocol): + """ + A *Connector* is any function that takes in the following two positional parameters as input: + + * The ``ip`` address as a :class:`str`. This might be either a IPv4 or IPv6. + * The ``port`` as a :class:`int`. This will be a number below 2¹⁶, often 443. + + and returns a :class:`tuple`\ [:class:`AsyncReader`, :class:`AsyncWriter`]. + + You can use a custom connector to connect to Telegram through proxies. + The library will only ever open remote connections through this function. + + The default connector is :func:`asyncio.open_connection`, defined as: + + .. code-block:: python + + default_connector = lambda ip, port: asyncio.open_connection(ip, port) + + If your connector needs additional parameters, you can use either the :keyword:`lambda` syntax or :func:`functools.partial`. + """ + + async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]: + pass + + class RequestState(ABC): pass @@ -80,8 +148,8 @@ class Sender: dc_id: int addr: str _logger: logging.Logger - _reader: StreamReader - _writer: StreamWriter + _reader: AsyncReader + _writer: AsyncWriter _transport: Transport _mtp: Mtp _mtp_buffer: bytearray @@ -98,9 +166,12 @@ class Sender: mtp: Mtp, dc_id: int, addr: str, + *, + connector: Connector, base_logger: logging.Logger, ) -> Self: - reader, writer = await asyncio.open_connection(*addr.split(":")) + ip, port = addr.split(":") + reader, writer = await connector(ip, int(port)) return cls( dc_id=dc_id, @@ -299,10 +370,33 @@ class Sender: async def connect( - transport: Transport, dc_id: int, addr: str, base_logger: logging.Logger + transport: Transport, + dc_id: int, + addr: str, + *, + auth_key: Optional[bytes], + base_logger: logging.Logger, + connector: Connector, ) -> Sender: - sender = await Sender.connect(transport, Plain(), dc_id, addr, base_logger) - return await generate_auth_key(sender) + if auth_key is None: + sender = await Sender.connect( + transport, + Plain(), + dc_id, + addr, + connector=connector, + base_logger=base_logger, + ) + return await generate_auth_key(sender) + else: + return await Sender.connect( + transport, + Encrypted(AuthKey.from_bytes(auth_key)), + dc_id, + addr, + connector=connector, + base_logger=base_logger, + ) async def generate_auth_key(sender: Sender) -> Sender: @@ -320,15 +414,3 @@ async def generate_auth_key(sender: Sender) -> Sender: sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY return sender - - -async def connect_with_auth( - transport: Transport, - dc_id: int, - addr: str, - auth_key: bytes, - base_logger: logging.Logger, -) -> Sender: - return await Sender.connect( - transport, Encrypted(AuthKey.from_bytes(auth_key)), dc_id, addr, base_logger - ) diff --git a/client/tests/mtsender_test.py b/client/tests/mtsender_test.py index cef8466e..93ab7cee 100644 --- a/client/tests/mtsender_test.py +++ b/client/tests/mtsender_test.py @@ -21,7 +21,13 @@ async def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None: return deadline - asyncio.get_running_loop().time() sender = await asyncio.wait_for( - connect(Full(), *TELEGRAM_TEST_DC, logging.getLogger(__file__)), + connect( + Full(), + *TELEGRAM_TEST_DC, + auth_key=None, + base_logger=logging.getLogger(__file__), + connector=lambda ip, port: asyncio.open_connection(ip, port), + ), timeout(), )