mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-10 08:30:52 +03:00
Support custom connectors
This commit is contained in:
parent
d80c6b3bb4
commit
6e88264b28
|
@ -107,3 +107,11 @@ Private definitions
|
||||||
.. autoclass:: InFileLike
|
.. autoclass:: InFileLike
|
||||||
|
|
||||||
.. autoclass:: OutFileLike
|
.. autoclass:: OutFileLike
|
||||||
|
|
||||||
|
.. currentmodule:: telethon._impl.mtsender.sender
|
||||||
|
|
||||||
|
.. autoclass:: AsyncReader
|
||||||
|
|
||||||
|
.. autoclass:: AsyncWriter
|
||||||
|
|
||||||
|
.. autoclass:: Connector
|
||||||
|
|
|
@ -49,7 +49,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
||||||
try:
|
try:
|
||||||
await client._storage.save(client._session)
|
await client._storage.save(client._session)
|
||||||
except Exception:
|
except Exception:
|
||||||
client._logger.exception(
|
client._config.base_logger.exception(
|
||||||
"failed to save session upon login; you may need to login again in future runs"
|
"failed to save session upon login; you may need to login again in future runs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
||||||
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
||||||
assert dc_id is not None
|
assert dc_id is not None
|
||||||
sender, client._session.dcs = await connect_sender(
|
sender, client._session.dcs = await connect_sender(
|
||||||
client._config, client._session.dcs, DataCenter(id=dc_id), client._logger
|
client._config, client._session.dcs, DataCenter(id=dc_id)
|
||||||
)
|
)
|
||||||
async with client._sender_lock:
|
async with client._sender_lock:
|
||||||
client._sender = sender
|
client._sender = sender
|
||||||
|
|
|
@ -19,12 +19,11 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from telethon._impl.session.session import DataCenter
|
|
||||||
|
|
||||||
from ....version import __version__ as default_version
|
from ....version import __version__ as default_version
|
||||||
from ...mtsender import Sender
|
from ...mtsender import Connector, Sender
|
||||||
from ...session import (
|
from ...session import (
|
||||||
ChatHashCache,
|
ChatHashCache,
|
||||||
|
DataCenter,
|
||||||
MemorySession,
|
MemorySession,
|
||||||
MessageBox,
|
MessageBox,
|
||||||
PackedChat,
|
PackedChat,
|
||||||
|
@ -193,6 +192,13 @@ class Client:
|
||||||
:param datacenter:
|
:param datacenter:
|
||||||
Override the datacenter to connect to.
|
Override the datacenter to connect to.
|
||||||
Useful to connect to one of Telegram's test servers.
|
Useful to connect to one of Telegram's test servers.
|
||||||
|
|
||||||
|
:param connector:
|
||||||
|
Asynchronous function called to connect to a remote address.
|
||||||
|
By default, this is :func:`asyncio.open_connection`.
|
||||||
|
In order to use proxies, you can set a custom connector.
|
||||||
|
|
||||||
|
See :class:`~telethon._impl.mtsender.sender.Connector` for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -212,10 +218,9 @@ class Client:
|
||||||
system_lang_code: Optional[str] = None,
|
system_lang_code: Optional[str] = None,
|
||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
datacenter: Optional[DataCenter] = None,
|
datacenter: Optional[DataCenter] = None,
|
||||||
|
connector: Optional[Connector] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._logger = logger or logging.getLogger(
|
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])
|
||||||
__package__[: __package__.index(".")]
|
|
||||||
)
|
|
||||||
|
|
||||||
self._sender: Optional[Sender] = None
|
self._sender: Optional[Sender] = None
|
||||||
self._sender_lock = asyncio.Lock()
|
self._sender_lock = asyncio.Lock()
|
||||||
|
@ -240,11 +245,13 @@ class Client:
|
||||||
if flood_sleep_threshold is None
|
if flood_sleep_threshold is None
|
||||||
else flood_sleep_threshold,
|
else flood_sleep_threshold,
|
||||||
update_queue_limit=update_queue_limit,
|
update_queue_limit=update_queue_limit,
|
||||||
|
base_logger=base_logger,
|
||||||
|
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._session = Session()
|
self._session = Session()
|
||||||
|
|
||||||
self._message_box = MessageBox(base_logger=self._logger)
|
self._message_box = MessageBox(base_logger=base_logger)
|
||||||
self._chat_hashes = ChatHashCache(None)
|
self._chat_hashes = ChatHashCache(None)
|
||||||
self._last_update_limit_warn: Optional[float] = None
|
self._last_update_limit_warn: Optional[float] = None
|
||||||
self._updates: asyncio.Queue[
|
self._updates: asyncio.Queue[
|
||||||
|
|
|
@ -10,9 +10,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
from ....version import __version__
|
from ....version import __version__
|
||||||
from ...mtproto import BadStatus, Full, RpcError
|
from ...mtproto import BadStatus, Full, RpcError
|
||||||
from ...mtsender import Sender
|
from ...mtsender import Connector, Sender
|
||||||
from ...mtsender import connect as connect_without_auth
|
from ...mtsender import connect as do_connect_sender
|
||||||
from ...mtsender import connect_with_auth
|
|
||||||
from ...session import DataCenter
|
from ...session import DataCenter
|
||||||
from ...session import User as SessionUser
|
from ...session import User as SessionUser
|
||||||
from ...tl import LAYER, Request, abcs, functions, types
|
from ...tl import LAYER, Request, abcs, functions, types
|
||||||
|
@ -45,6 +44,8 @@ def default_system_version() -> str:
|
||||||
class Config:
|
class Config:
|
||||||
api_id: int
|
api_id: int
|
||||||
api_hash: str
|
api_hash: str
|
||||||
|
base_logger: logging.Logger
|
||||||
|
connector: Connector
|
||||||
device_model: str = field(default_factory=default_device_model)
|
device_model: str = field(default_factory=default_device_model)
|
||||||
system_version: str = field(default_factory=default_system_version)
|
system_version: str = field(default_factory=default_system_version)
|
||||||
app_version: str = __version__
|
app_version: str = __version__
|
||||||
|
@ -76,7 +77,6 @@ async def connect_sender(
|
||||||
config: Config,
|
config: Config,
|
||||||
known_dcs: List[DataCenter],
|
known_dcs: List[DataCenter],
|
||||||
dc: DataCenter,
|
dc: DataCenter,
|
||||||
base_logger: logging.Logger,
|
|
||||||
force_auth_gen: bool = False,
|
force_auth_gen: bool = False,
|
||||||
) -> Tuple[Sender, List[DataCenter]]:
|
) -> Tuple[Sender, List[DataCenter]]:
|
||||||
# Only the ID of the input DC may be known.
|
# Only the ID of the input DC may be known.
|
||||||
|
@ -93,11 +93,14 @@ async def connect_sender(
|
||||||
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
|
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
|
||||||
)
|
)
|
||||||
|
|
||||||
transport = Full()
|
sender = await do_connect_sender(
|
||||||
if auth:
|
Full(),
|
||||||
sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger)
|
dc.id,
|
||||||
else:
|
addr,
|
||||||
sender = await connect_without_auth(transport, dc.id, addr, base_logger)
|
auth_key=auth,
|
||||||
|
base_logger=config.base_logger,
|
||||||
|
connector=config.connector,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
remote_config_data = await sender.invoke(
|
remote_config_data = await sender.invoke(
|
||||||
|
@ -122,13 +125,11 @@ async def connect_sender(
|
||||||
dc = DataCenter(
|
dc = DataCenter(
|
||||||
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
|
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
|
||||||
)
|
)
|
||||||
base_logger.warning(
|
config.base_logger.warning(
|
||||||
"datacenter could not find stored auth; will retry generating a new one: %s",
|
"datacenter could not find stored auth; will retry generating a new one: %s",
|
||||||
dc,
|
dc,
|
||||||
)
|
)
|
||||||
return await connect_sender(
|
return await connect_sender(config, known_dcs, dc, force_auth_gen=True)
|
||||||
config, known_dcs, dc, base_logger, force_auth_gen=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -177,7 +178,7 @@ async def connect(self: Client) -> None:
|
||||||
id=self._session.user.dc if self._session.user else DEFAULT_DC
|
id=self._session.user.dc if self._session.user else DEFAULT_DC
|
||||||
)
|
)
|
||||||
self._sender, self._session.dcs = await connect_sender(
|
self._sender, self._session.dcs = await connect_sender(
|
||||||
self._config, self._session.dcs, datacenter, self._logger
|
self._config, self._session.dcs, datacenter
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._message_box.is_empty() and self._session.user:
|
if self._message_box.is_empty() and self._session.user:
|
||||||
|
@ -216,7 +217,7 @@ async def disconnect(self: Client) -> None:
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
self._logger.exception(
|
self._config.base_logger.exception(
|
||||||
"unhandled exception when cancelling dispatcher; this is a bug"
|
"unhandled exception when cancelling dispatcher; this is a bug"
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
@ -225,7 +226,9 @@ async def disconnect(self: Client) -> None:
|
||||||
try:
|
try:
|
||||||
await sender.disconnect()
|
await sender.disconnect()
|
||||||
except Exception:
|
except Exception:
|
||||||
self._logger.exception("unhandled exception during disconnect; this is a bug")
|
self._config.base_logger.exception(
|
||||||
|
"unhandled exception during disconnect; this is a bug"
|
||||||
|
)
|
||||||
|
|
||||||
self._session.state = self._message_box.session_state()
|
self._session.state = self._message_box.session_state()
|
||||||
await self._storage.save(self._session)
|
await self._storage.save(self._session)
|
||||||
|
|
|
@ -116,7 +116,7 @@ def extend_update_queue(
|
||||||
now - client._last_update_limit_warn
|
now - client._last_update_limit_warn
|
||||||
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
|
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
|
||||||
):
|
):
|
||||||
client._logger.warning(
|
client._config.base_logger.warning(
|
||||||
"updates are being dropped because limit=%d has been reached",
|
"updates are being dropped because limit=%d has been reached",
|
||||||
client._updates.maxsize,
|
client._updates.maxsize,
|
||||||
)
|
)
|
||||||
|
@ -134,13 +134,13 @@ async def dispatcher(client: Client) -> None:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, RuntimeError) and loop.is_closed():
|
if isinstance(e, RuntimeError) and loop.is_closed():
|
||||||
# User probably forgot to call disconnect.
|
# User probably forgot to call disconnect.
|
||||||
client._logger.warning(
|
client._config.base_logger.warning(
|
||||||
"client was not closed cleanly, make sure to call client.disconnect()! %s",
|
"client was not closed cleanly, make sure to call client.disconnect()! %s",
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
client._logger.exception(
|
client._config.base_logger.exception(
|
||||||
"unhandled exception in event handler; this is probably a bug in your code, not telethon"
|
"unhandled exception in event handler; this is probably a bug in your code, not telethon"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
|
@ -2,16 +2,20 @@ from .sender import (
|
||||||
MAXIMUM_DATA,
|
MAXIMUM_DATA,
|
||||||
NO_PING_DISCONNECT,
|
NO_PING_DISCONNECT,
|
||||||
PING_DELAY,
|
PING_DELAY,
|
||||||
|
AsyncReader,
|
||||||
|
AsyncWriter,
|
||||||
|
Connector,
|
||||||
Sender,
|
Sender,
|
||||||
connect,
|
connect,
|
||||||
connect_with_auth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MAXIMUM_DATA",
|
"MAXIMUM_DATA",
|
||||||
"NO_PING_DISCONNECT",
|
"NO_PING_DISCONNECT",
|
||||||
"PING_DELAY",
|
"PING_DELAY",
|
||||||
|
"AsyncReader",
|
||||||
|
"AsyncWriter",
|
||||||
|
"Connector",
|
||||||
"Sender",
|
"Sender",
|
||||||
"connect",
|
"connect",
|
||||||
"connect_with_auth",
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,9 +3,9 @@ import logging
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter
|
from asyncio import FIRST_COMPLETED, Event, Future
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, List, Optional, Self, TypeVar
|
from typing import Generic, List, Optional, Protocol, Self, Tuple, TypeVar
|
||||||
|
|
||||||
from ..crypto import AuthKey
|
from ..crypto import AuthKey
|
||||||
from ..mtproto import (
|
from ..mtproto import (
|
||||||
|
@ -43,6 +43,74 @@ def generate_random_id() -> int:
|
||||||
return _last_id
|
return _last_id
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncReader(Protocol):
|
||||||
|
"""
|
||||||
|
A :class:`asyncio.StreamReader`-like class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def read(self, n: int) -> bytes:
|
||||||
|
"""
|
||||||
|
Must behave like :meth:`asyncio.StreamReader.read`.
|
||||||
|
|
||||||
|
:param n:
|
||||||
|
Amount of bytes to read at most.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncWriter(Protocol):
|
||||||
|
"""
|
||||||
|
A :class:`asyncio.StreamWriter`-like class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Must behave like :meth:`asyncio.StreamWriter.write`.
|
||||||
|
|
||||||
|
:param data:
|
||||||
|
Data that must be entirely written or buffered until :meth:`drain` is called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def drain(self) -> None:
|
||||||
|
"""
|
||||||
|
Must behave like :meth:`asyncio.StreamWriter.drain`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Must behave like :meth:`asyncio.StreamWriter.close`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def wait_closed(self) -> None:
|
||||||
|
"""
|
||||||
|
Must behave like :meth:`asyncio.StreamWriter.wait_closed`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Connector(Protocol):
|
||||||
|
"""
|
||||||
|
A *Connector* is any function that takes in the following two positional parameters as input:
|
||||||
|
|
||||||
|
* The ``ip`` address as a :class:`str`. This might be either a IPv4 or IPv6.
|
||||||
|
* The ``port`` as a :class:`int`. This will be a number below 2¹⁶, often 443.
|
||||||
|
|
||||||
|
and returns a :class:`tuple`\ [:class:`AsyncReader`, :class:`AsyncWriter`].
|
||||||
|
|
||||||
|
You can use a custom connector to connect to Telegram through proxies.
|
||||||
|
The library will only ever open remote connections through this function.
|
||||||
|
|
||||||
|
The default connector is :func:`asyncio.open_connection`, defined as:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
default_connector = lambda ip, port: asyncio.open_connection(ip, port)
|
||||||
|
|
||||||
|
If your connector needs additional parameters, you can use either the :keyword:`lambda` syntax or :func:`functools.partial`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RequestState(ABC):
|
class RequestState(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -80,8 +148,8 @@ class Sender:
|
||||||
dc_id: int
|
dc_id: int
|
||||||
addr: str
|
addr: str
|
||||||
_logger: logging.Logger
|
_logger: logging.Logger
|
||||||
_reader: StreamReader
|
_reader: AsyncReader
|
||||||
_writer: StreamWriter
|
_writer: AsyncWriter
|
||||||
_transport: Transport
|
_transport: Transport
|
||||||
_mtp: Mtp
|
_mtp: Mtp
|
||||||
_mtp_buffer: bytearray
|
_mtp_buffer: bytearray
|
||||||
|
@ -98,9 +166,12 @@ class Sender:
|
||||||
mtp: Mtp,
|
mtp: Mtp,
|
||||||
dc_id: int,
|
dc_id: int,
|
||||||
addr: str,
|
addr: str,
|
||||||
|
*,
|
||||||
|
connector: Connector,
|
||||||
base_logger: logging.Logger,
|
base_logger: logging.Logger,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
reader, writer = await asyncio.open_connection(*addr.split(":"))
|
ip, port = addr.split(":")
|
||||||
|
reader, writer = await connector(ip, int(port))
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
dc_id=dc_id,
|
dc_id=dc_id,
|
||||||
|
@ -299,10 +370,33 @@ class Sender:
|
||||||
|
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
transport: Transport, dc_id: int, addr: str, base_logger: logging.Logger
|
transport: Transport,
|
||||||
|
dc_id: int,
|
||||||
|
addr: str,
|
||||||
|
*,
|
||||||
|
auth_key: Optional[bytes],
|
||||||
|
base_logger: logging.Logger,
|
||||||
|
connector: Connector,
|
||||||
) -> Sender:
|
) -> Sender:
|
||||||
sender = await Sender.connect(transport, Plain(), dc_id, addr, base_logger)
|
if auth_key is None:
|
||||||
|
sender = await Sender.connect(
|
||||||
|
transport,
|
||||||
|
Plain(),
|
||||||
|
dc_id,
|
||||||
|
addr,
|
||||||
|
connector=connector,
|
||||||
|
base_logger=base_logger,
|
||||||
|
)
|
||||||
return await generate_auth_key(sender)
|
return await generate_auth_key(sender)
|
||||||
|
else:
|
||||||
|
return await Sender.connect(
|
||||||
|
transport,
|
||||||
|
Encrypted(AuthKey.from_bytes(auth_key)),
|
||||||
|
dc_id,
|
||||||
|
addr,
|
||||||
|
connector=connector,
|
||||||
|
base_logger=base_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def generate_auth_key(sender: Sender) -> Sender:
|
async def generate_auth_key(sender: Sender) -> Sender:
|
||||||
|
@ -320,15 +414,3 @@ async def generate_auth_key(sender: Sender) -> Sender:
|
||||||
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
|
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
|
||||||
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
|
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
|
||||||
return sender
|
return sender
|
||||||
|
|
||||||
|
|
||||||
async def connect_with_auth(
|
|
||||||
transport: Transport,
|
|
||||||
dc_id: int,
|
|
||||||
addr: str,
|
|
||||||
auth_key: bytes,
|
|
||||||
base_logger: logging.Logger,
|
|
||||||
) -> Sender:
|
|
||||||
return await Sender.connect(
|
|
||||||
transport, Encrypted(AuthKey.from_bytes(auth_key)), dc_id, addr, base_logger
|
|
||||||
)
|
|
||||||
|
|
|
@ -21,7 +21,13 @@ async def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None:
|
||||||
return deadline - asyncio.get_running_loop().time()
|
return deadline - asyncio.get_running_loop().time()
|
||||||
|
|
||||||
sender = await asyncio.wait_for(
|
sender = await asyncio.wait_for(
|
||||||
connect(Full(), *TELEGRAM_TEST_DC, logging.getLogger(__file__)),
|
connect(
|
||||||
|
Full(),
|
||||||
|
*TELEGRAM_TEST_DC,
|
||||||
|
auth_key=None,
|
||||||
|
base_logger=logging.getLogger(__file__),
|
||||||
|
connector=lambda ip, port: asyncio.open_connection(ip, port),
|
||||||
|
),
|
||||||
timeout(),
|
timeout(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user