Implement client connection

This commit is contained in:
Lonami Exo 2023-09-01 13:42:23 +02:00
parent 87ffdca4c2
commit 294f7dedd5
3 changed files with 261 additions and 49 deletions

View File

@ -1,6 +1,13 @@
import asyncio
from collections import deque
from types import TracebackType
from typing import Optional, Type
from typing import Deque, Optional, Self, Type, TypeVar
from ...mtsender.sender import Sender
from ...session.chat.hash_cache import ChatHashCache
from ...session.message_box.messagebox import MessageBox
from ...tl import abcs
from ...tl.core.request import Request
from .account import edit_2fa, end_takeout, takeout
from .auth import log_out, qr_login, send_code_request, sign_in, sign_up, start
from .bots import inline_query
@ -29,13 +36,14 @@ from .messages import (
unpin_message,
)
from .net import (
DEFAULT_DC,
Config,
connect,
connected,
disconnect,
disconnected,
flood_sleep_threshold,
is_connected,
loop,
set_proxy,
invoke_request,
run_until_disconnected,
step,
)
from .updates import (
add_event_handler,
@ -43,7 +51,6 @@ from .updates import (
list_event_handlers,
on,
remove_event_handler,
run_until_disconnected,
set_receive_updates,
)
from .uploads import send_file, upload_file
@ -56,8 +63,26 @@ from .users import (
is_user_authorized,
)
Return = TypeVar("Return")
class Client:
def __init__(self, config: Config) -> None:
self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock()
self._dc_id = DEFAULT_DC
self._config = config
self._message_box = MessageBox()
self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn = None
self._updates: Deque[abcs.Update] = deque(maxlen=config.update_queue_limit)
self._downloader_map = object()
if self_user := config.session.user:
self._dc_id = self_user.dc
if config.catch_up and config.session.state:
self._message_box.load(config.session.state)
def takeout(self) -> None:
takeout(self)
@ -169,9 +194,6 @@ class Client:
async def set_receive_updates(self) -> None:
await set_receive_updates(self)
def run_until_disconnected(self) -> None:
run_until_disconnected(self)
def on(self) -> None:
on(self)
@ -211,29 +233,31 @@ class Client:
async def get_peer_id(self) -> None:
await get_peer_id(self)
def loop(self) -> None:
loop(self)
def disconnected(self) -> None:
disconnected(self)
def flood_sleep_threshold(self) -> None:
flood_sleep_threshold(self)
async def connect(self) -> None:
await connect(self)
def is_connected(self) -> None:
is_connected(self)
async def disconnect(self) -> None:
await disconnect(self)
def disconnect(self) -> None:
disconnect(self)
async def __call__(self, request: Request[Return]) -> Return:
if not self._sender:
raise ConnectionError("not connected")
def set_proxy(self) -> None:
set_proxy(self)
return await invoke_request(self, self._sender, self._sender_lock, request)
async def __aenter__(self) -> None:
raise NotImplementedError
async def step(self) -> None:
await step(self)
async def run_until_disconnected(self) -> None:
await run_until_disconnected(self)
@property
def connected(self) -> bool:
return connected(self)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__(
self,
@ -242,4 +266,4 @@ class Client:
tb: Optional[TracebackType],
) -> None:
exc_type, exc, tb
raise NotImplementedError
await self.disconnect()

View File

@ -1,41 +1,192 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import asyncio
import platform
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, TypeVar
from ....version import __version__
from ...mtproto.mtp.types import RpcError
from ...mtproto.transport.full import Full
from ...mtsender.sender import Sender
from ...mtsender.sender import connect as connect_without_auth
from ...mtsender.sender import connect_with_auth
from ...session.message_box.defs import DataCenter, Session
from ...tl import LAYER, functions
from ...tl.core.request import Request
if TYPE_CHECKING:
from .client import Client
def loop(self: Client) -> None:
self
raise NotImplementedError
Return = TypeVar("Return")
def disconnected(self: Client) -> None:
self
raise NotImplementedError
def default_device_model() -> str:
system = platform.uname()
if system.machine in ("x86_64", "AMD64"):
return "PC 64bit"
elif system.machine in ("i386", "i686", "x86"):
return "PC 32bit"
else:
return system.machine or "Unknown"
def flood_sleep_threshold(self: Client) -> None:
self
raise NotImplementedError
def default_system_version() -> str:
system = platform.uname()
return re.sub(r"-.+", "", system.release) or "1.0"
@dataclass
class Config:
session: Session
api_id: int
api_hash: str
device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version)
app_version: str = __version__
system_lang_code: str = "en"
lang_code: str = "en"
catch_up: bool = False
server_addr: Optional[str] = None
flood_sleep_threshold: Optional[int] = 60
update_queue_limit: Optional[int] = None
# dc_id to IPv4 and port pair
DC_ADDRESSES = [
"0.0.0.0:0",
"149.154.175.53:443",
"149.154.167.51:443",
"149.154.175.100:443",
"149.154.167.92:443",
"91.108.56.190:443",
]
DEFAULT_DC = 2
async def connect_sender(dc_id: int, config: Config) -> Sender:
transport = Full()
if config.server_addr:
addr = config.server_addr
else:
addr = DC_ADDRESSES[dc_id]
auth_key: Optional[bytes] = None
for dc in config.session.dcs:
if dc.id == dc_id:
if dc.auth:
auth_key = dc.auth
break
if auth_key:
sender = await connect_with_auth(transport, addr, auth_key)
else:
sender = await connect_without_auth(transport, addr)
for dc in config.session.dcs:
if dc.id == dc_id:
dc.auth = sender.auth_key
break
else:
config.session.dcs.append(
DataCenter(id=dc_id, addr=addr, auth=sender.auth_key)
)
# TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it)
# TODO all up-to-date server addresses should be stored in the session for future initial connections
remote_config = await sender.invoke(
functions.invoke_with_layer(
layer=LAYER,
query=functions.init_connection(
api_id=config.api_id,
device_model=config.device_model,
system_version=config.system_version,
app_version=config.app_version,
system_lang_code=config.system_lang_code,
lang_pack="",
lang_code=config.lang_code,
proxy=None,
params=None,
query=functions.help.get_config(),
),
)
)
remote_config
return sender
async def connect(self: Client) -> None:
self
raise NotImplementedError
self._sender = await connect_sender(self._dc_id, self._config)
if self._message_box.is_empty() and self._config.session.user:
try:
await self(functions.updates.get_state())
except Exception:
pass
def is_connected(self: Client) -> None:
self
raise NotImplementedError
async def disconnect(self: Client) -> None:
if not self._sender:
return
await self._sender.disconnect()
self._sender = None
def disconnect(self: Client) -> None:
self
raise NotImplementedError
async def invoke_request(
self: Client,
sender: Sender,
lock: asyncio.Lock,
request: Request[Return],
) -> Return:
slept_flood = False
sleep_thresh = self._config.flood_sleep_threshold or 0
rx = sender.enqueue(request)
while True:
while not rx.done():
await step_sender(self, sender, lock)
try:
response = rx.result()
break
except RpcError as e:
if (
e.code == 420
and e.value is not None
and not slept_flood
and e.value < sleep_thresh
):
await asyncio.sleep(e.value)
slept_flood = True
rx = sender.enqueue(request)
continue
else:
raise
return request.deserialize_response(response)
def set_proxy(self: Client) -> None:
self
raise NotImplementedError
async def step(self: Client) -> None:
if self._sender:
await step_sender(self, self._sender, self._sender_lock)
async def step_sender(self: Client, sender: Sender, lock: asyncio.Lock) -> None:
if lock.locked():
async with lock:
pass
else:
async with lock:
updates = await sender.step()
# self._process_socket_updates(updates)
async def run_until_disconnected(self: Client) -> None:
while self.connected:
await self.step()
def connected(self: Client) -> bool:
return self._sender is not None

View File

@ -0,0 +1,37 @@
import asyncio
import os
import random
from telethon._impl.client.client.client import Client
from telethon._impl.client.client.net import Config
from telethon._impl.session.message_box.defs import Session
from telethon._impl.tl.mtproto import functions, types
def test_ping_pong() -> None:
async def func() -> None:
api_id = os.getenv("TG_ID")
api_hash = os.getenv("TG_HASH")
assert api_id and api_id.isdigit()
assert api_hash
client = Client(
Config(
session=Session(
dcs=[],
user=None,
state=None,
),
api_id=int(api_id),
api_hash=api_hash,
)
)
assert not client.connected
await client.connect()
assert client.connected
ping_id = random.randrange(-(2**63), 2**63)
pong = await client(functions.ping(ping_id=ping_id))
assert isinstance(pong, types.Pong)
assert pong.ping_id == ping_id
asyncio.run(func())