diff --git a/client/src/telethon/_impl/client/client/client.py b/client/src/telethon/_impl/client/client/client.py index c068d796..4f8fe04d 100644 --- a/client/src/telethon/_impl/client/client/client.py +++ b/client/src/telethon/_impl/client/client/client.py @@ -1,6 +1,13 @@ +import asyncio +from collections import deque from types import TracebackType -from typing import Optional, Type +from typing import Deque, Optional, Self, Type, TypeVar +from ...mtsender.sender import Sender +from ...session.chat.hash_cache import ChatHashCache +from ...session.message_box.messagebox import MessageBox +from ...tl import abcs +from ...tl.core.request import Request from .account import edit_2fa, end_takeout, takeout from .auth import log_out, qr_login, send_code_request, sign_in, sign_up, start from .bots import inline_query @@ -29,13 +36,14 @@ from .messages import ( unpin_message, ) from .net import ( + DEFAULT_DC, + Config, connect, + connected, disconnect, - disconnected, - flood_sleep_threshold, - is_connected, - loop, - set_proxy, + invoke_request, + run_until_disconnected, + step, ) from .updates import ( add_event_handler, @@ -43,7 +51,6 @@ from .updates import ( list_event_handlers, on, remove_event_handler, - run_until_disconnected, set_receive_updates, ) from .uploads import send_file, upload_file @@ -56,8 +63,26 @@ from .users import ( is_user_authorized, ) +Return = TypeVar("Return") + class Client: + def __init__(self, config: Config) -> None: + self._sender: Optional[Sender] = None + self._sender_lock = asyncio.Lock() + self._dc_id = DEFAULT_DC + self._config = config + self._message_box = MessageBox() + self._chat_hashes = ChatHashCache(None) + self._last_update_limit_warn = None + self._updates: Deque[abcs.Update] = deque(maxlen=config.update_queue_limit) + self._downloader_map = object() + + if self_user := config.session.user: + self._dc_id = self_user.dc + if config.catch_up and config.session.state: + self._message_box.load(config.session.state) + def takeout(self) -> None: takeout(self) @@ -169,9 +194,6 @@ class Client: async def set_receive_updates(self) -> None: await set_receive_updates(self) - def run_until_disconnected(self) -> None: - run_until_disconnected(self) - def on(self) -> None: on(self) @@ -211,29 +233,31 @@ class Client: async def get_peer_id(self) -> None: await get_peer_id(self) - def loop(self) -> None: - loop(self) - - def disconnected(self) -> None: - disconnected(self) - - def flood_sleep_threshold(self) -> None: - flood_sleep_threshold(self) - async def connect(self) -> None: await connect(self) - def is_connected(self) -> None: - is_connected(self) + async def disconnect(self) -> None: + await disconnect(self) - def disconnect(self) -> None: - disconnect(self) + async def __call__(self, request: Request[Return]) -> Return: + if not self._sender: + raise ConnectionError("not connected") - def set_proxy(self) -> None: - set_proxy(self) + return await invoke_request(self, self._sender, self._sender_lock, request) - async def __aenter__(self) -> None: - raise NotImplementedError + async def step(self) -> None: + await step(self) + + async def run_until_disconnected(self) -> None: + await run_until_disconnected(self) + + @property + def connected(self) -> bool: + return connected(self) + + async def __aenter__(self) -> Self: + await self.connect() + return self async def __aexit__( self, @@ -242,4 +266,4 @@ class Client: tb: Optional[TracebackType], ) -> None: exc_type, exc, tb - raise NotImplementedError + await self.disconnect() diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 921baa6e..2d48081d 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -1,41 +1,192 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import asyncio +import platform +import re +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Optional, TypeVar + +from ....version import __version__ +from ...mtproto.mtp.types import RpcError +from ...mtproto.transport.full import Full +from ...mtsender.sender import Sender +from ...mtsender.sender import connect as connect_without_auth +from ...mtsender.sender import connect_with_auth +from ...session.message_box.defs import DataCenter, Session +from ...tl import LAYER, functions +from ...tl.core.request import Request if TYPE_CHECKING: from .client import Client -def loop(self: Client) -> None: - self - raise NotImplementedError +Return = TypeVar("Return") -def disconnected(self: Client) -> None: - self - raise NotImplementedError +def default_device_model() -> str: + system = platform.uname() + if system.machine in ("x86_64", "AMD64"): + return "PC 64bit" + elif system.machine in ("i386", "i686", "x86"): + return "PC 32bit" + else: + return system.machine or "Unknown" -def flood_sleep_threshold(self: Client) -> None: - self - raise NotImplementedError +def default_system_version() -> str: + system = platform.uname() + return re.sub(r"-.+", "", system.release) or "1.0" + + +@dataclass +class Config: + session: Session + api_id: int + api_hash: str + device_model: str = field(default_factory=default_device_model) + system_version: str = field(default_factory=default_system_version) + app_version: str = __version__ + system_lang_code: str = "en" + lang_code: str = "en" + catch_up: bool = False + server_addr: Optional[str] = None + flood_sleep_threshold: Optional[int] = 60 + update_queue_limit: Optional[int] = None + + +# dc_id to IPv4 and port pair +DC_ADDRESSES = [ + "0.0.0.0:0", + "149.154.175.53:443", + "149.154.167.51:443", + "149.154.175.100:443", + "149.154.167.92:443", + "91.108.56.190:443", +] + +DEFAULT_DC = 2 + + +async def connect_sender(dc_id: int, config: Config) -> Sender: + transport = Full() + + if config.server_addr: + addr = config.server_addr + else: + addr = DC_ADDRESSES[dc_id] + + auth_key: Optional[bytes] = None + for dc in config.session.dcs: + if dc.id == dc_id: + if dc.auth: + auth_key = dc.auth + break + + if auth_key: + sender = await connect_with_auth(transport, addr, auth_key) + else: + sender = await connect_without_auth(transport, addr) + for dc in config.session.dcs: + if dc.id == dc_id: + dc.auth = sender.auth_key + break + else: + config.session.dcs.append( + DataCenter(id=dc_id, addr=addr, auth=sender.auth_key) + ) + + # TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it) + # TODO all up-to-date server addresses should be stored in the session for future initial connections + remote_config = await sender.invoke( + functions.invoke_with_layer( + layer=LAYER, + query=functions.init_connection( + api_id=config.api_id, + device_model=config.device_model, + system_version=config.system_version, + app_version=config.app_version, + system_lang_code=config.system_lang_code, + lang_pack="", + lang_code=config.lang_code, + proxy=None, + params=None, + query=functions.help.get_config(), + ), + ) + ) + remote_config + + return sender async def connect(self: Client) -> None: - self - raise NotImplementedError + self._sender = await connect_sender(self._dc_id, self._config) + + if self._message_box.is_empty() and self._config.session.user: + try: + await self(functions.updates.get_state()) + except Exception: + pass -def is_connected(self: Client) -> None: - self - raise NotImplementedError +async def disconnect(self: Client) -> None: + if not self._sender: + return + + await self._sender.disconnect() + self._sender = None -def disconnect(self: Client) -> None: - self - raise NotImplementedError +async def invoke_request( + self: Client, + sender: Sender, + lock: asyncio.Lock, + request: Request[Return], +) -> Return: + slept_flood = False + sleep_thresh = self._config.flood_sleep_threshold or 0 + rx = sender.enqueue(request) + while True: + while not rx.done(): + await step_sender(self, sender, lock) + try: + response = rx.result() + break + except RpcError as e: + if ( + e.code == 420 + and e.value is not None + and not slept_flood + and e.value < sleep_thresh + ): + await asyncio.sleep(e.value) + slept_flood = True + rx = sender.enqueue(request) + continue + else: + raise + return request.deserialize_response(response) -def set_proxy(self: Client) -> None: - self - raise NotImplementedError +async def step(self: Client) -> None: + if self._sender: + await step_sender(self, self._sender, self._sender_lock) + + +async def step_sender(self: Client, sender: Sender, lock: asyncio.Lock) -> None: + if lock.locked(): + async with lock: + pass + else: + async with lock: + updates = await sender.step() + # self._process_socket_updates(updates) + + +async def run_until_disconnected(self: Client) -> None: + while self.connected: + await self.step() + + +def connected(self: Client) -> bool: + return self._sender is not None diff --git a/client/tests/client_test.py b/client/tests/client_test.py new file mode 100644 index 00000000..ea5b9c4a --- /dev/null +++ b/client/tests/client_test.py @@ -0,0 +1,37 @@ +import asyncio +import os +import random + +from telethon._impl.client.client.client import Client +from telethon._impl.client.client.net import Config +from telethon._impl.session.message_box.defs import Session +from telethon._impl.tl.mtproto import functions, types + + +def test_ping_pong() -> None: + async def func() -> None: + api_id = os.getenv("TG_ID") + api_hash = os.getenv("TG_HASH") + assert api_id and api_id.isdigit() + assert api_hash + client = Client( + Config( + session=Session( + dcs=[], + user=None, + state=None, + ), + api_id=int(api_id), + api_hash=api_hash, + ) + ) + assert not client.connected + await client.connect() + assert client.connected + + ping_id = random.randrange(-(2**63), 2**63) + pong = await client(functions.ping(ping_id=ping_id)) + assert isinstance(pong, types.Pong) + assert pong.ping_id == ping_id + + asyncio.run(func())