Implement client connection

This commit is contained in:
Lonami Exo 2023-09-01 13:42:23 +02:00
parent 87ffdca4c2
commit 294f7dedd5
3 changed files with 261 additions and 49 deletions

View File

@ -1,6 +1,13 @@
import asyncio
from collections import deque
from types import TracebackType from types import TracebackType
from typing import Optional, Type from typing import Deque, Optional, Self, Type, TypeVar
from ...mtsender.sender import Sender
from ...session.chat.hash_cache import ChatHashCache
from ...session.message_box.messagebox import MessageBox
from ...tl import abcs
from ...tl.core.request import Request
from .account import edit_2fa, end_takeout, takeout from .account import edit_2fa, end_takeout, takeout
from .auth import log_out, qr_login, send_code_request, sign_in, sign_up, start from .auth import log_out, qr_login, send_code_request, sign_in, sign_up, start
from .bots import inline_query from .bots import inline_query
@ -29,13 +36,14 @@ from .messages import (
unpin_message, unpin_message,
) )
from .net import ( from .net import (
DEFAULT_DC,
Config,
connect, connect,
connected,
disconnect, disconnect,
disconnected, invoke_request,
flood_sleep_threshold, run_until_disconnected,
is_connected, step,
loop,
set_proxy,
) )
from .updates import ( from .updates import (
add_event_handler, add_event_handler,
@ -43,7 +51,6 @@ from .updates import (
list_event_handlers, list_event_handlers,
on, on,
remove_event_handler, remove_event_handler,
run_until_disconnected,
set_receive_updates, set_receive_updates,
) )
from .uploads import send_file, upload_file from .uploads import send_file, upload_file
@ -56,8 +63,26 @@ from .users import (
is_user_authorized, is_user_authorized,
) )
Return = TypeVar("Return")
class Client: class Client:
def __init__(self, config: Config) -> None:
self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock()
self._dc_id = DEFAULT_DC
self._config = config
self._message_box = MessageBox()
self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn = None
self._updates: Deque[abcs.Update] = deque(maxlen=config.update_queue_limit)
self._downloader_map = object()
if self_user := config.session.user:
self._dc_id = self_user.dc
if config.catch_up and config.session.state:
self._message_box.load(config.session.state)
def takeout(self) -> None: def takeout(self) -> None:
takeout(self) takeout(self)
@ -169,9 +194,6 @@ class Client:
async def set_receive_updates(self) -> None: async def set_receive_updates(self) -> None:
await set_receive_updates(self) await set_receive_updates(self)
def run_until_disconnected(self) -> None:
run_until_disconnected(self)
def on(self) -> None: def on(self) -> None:
on(self) on(self)
@ -211,29 +233,31 @@ class Client:
async def get_peer_id(self) -> None: async def get_peer_id(self) -> None:
await get_peer_id(self) await get_peer_id(self)
def loop(self) -> None:
loop(self)
def disconnected(self) -> None:
disconnected(self)
def flood_sleep_threshold(self) -> None:
flood_sleep_threshold(self)
async def connect(self) -> None: async def connect(self) -> None:
await connect(self) await connect(self)
def is_connected(self) -> None: async def disconnect(self) -> None:
is_connected(self) await disconnect(self)
def disconnect(self) -> None: async def __call__(self, request: Request[Return]) -> Return:
disconnect(self) if not self._sender:
raise ConnectionError("not connected")
def set_proxy(self) -> None: return await invoke_request(self, self._sender, self._sender_lock, request)
set_proxy(self)
async def __aenter__(self) -> None: async def step(self) -> None:
raise NotImplementedError await step(self)
async def run_until_disconnected(self) -> None:
await run_until_disconnected(self)
@property
def connected(self) -> bool:
return connected(self)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__( async def __aexit__(
self, self,
@ -242,4 +266,4 @@ class Client:
tb: Optional[TracebackType], tb: Optional[TracebackType],
) -> None: ) -> None:
exc_type, exc, tb exc_type, exc, tb
raise NotImplementedError await self.disconnect()

View File

@ -1,41 +1,192 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING import asyncio
import platform
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, TypeVar
from ....version import __version__
from ...mtproto.mtp.types import RpcError
from ...mtproto.transport.full import Full
from ...mtsender.sender import Sender
from ...mtsender.sender import connect as connect_without_auth
from ...mtsender.sender import connect_with_auth
from ...session.message_box.defs import DataCenter, Session
from ...tl import LAYER, functions
from ...tl.core.request import Request
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import Client from .client import Client
def loop(self: Client) -> None: Return = TypeVar("Return")
self
raise NotImplementedError
def disconnected(self: Client) -> None: def default_device_model() -> str:
self system = platform.uname()
raise NotImplementedError if system.machine in ("x86_64", "AMD64"):
return "PC 64bit"
elif system.machine in ("i386", "i686", "x86"):
return "PC 32bit"
else:
return system.machine or "Unknown"
def flood_sleep_threshold(self: Client) -> None: def default_system_version() -> str:
self system = platform.uname()
raise NotImplementedError return re.sub(r"-.+", "", system.release) or "1.0"
@dataclass
class Config:
session: Session
api_id: int
api_hash: str
device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version)
app_version: str = __version__
system_lang_code: str = "en"
lang_code: str = "en"
catch_up: bool = False
server_addr: Optional[str] = None
flood_sleep_threshold: Optional[int] = 60
update_queue_limit: Optional[int] = None
# dc_id to IPv4 and port pair
DC_ADDRESSES = [
"0.0.0.0:0",
"149.154.175.53:443",
"149.154.167.51:443",
"149.154.175.100:443",
"149.154.167.92:443",
"91.108.56.190:443",
]
DEFAULT_DC = 2
async def connect_sender(dc_id: int, config: Config) -> Sender:
transport = Full()
if config.server_addr:
addr = config.server_addr
else:
addr = DC_ADDRESSES[dc_id]
auth_key: Optional[bytes] = None
for dc in config.session.dcs:
if dc.id == dc_id:
if dc.auth:
auth_key = dc.auth
break
if auth_key:
sender = await connect_with_auth(transport, addr, auth_key)
else:
sender = await connect_without_auth(transport, addr)
for dc in config.session.dcs:
if dc.id == dc_id:
dc.auth = sender.auth_key
break
else:
config.session.dcs.append(
DataCenter(id=dc_id, addr=addr, auth=sender.auth_key)
)
# TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it)
# TODO all up-to-date server addresses should be stored in the session for future initial connections
remote_config = await sender.invoke(
functions.invoke_with_layer(
layer=LAYER,
query=functions.init_connection(
api_id=config.api_id,
device_model=config.device_model,
system_version=config.system_version,
app_version=config.app_version,
system_lang_code=config.system_lang_code,
lang_pack="",
lang_code=config.lang_code,
proxy=None,
params=None,
query=functions.help.get_config(),
),
)
)
remote_config
return sender
async def connect(self: Client) -> None: async def connect(self: Client) -> None:
self self._sender = await connect_sender(self._dc_id, self._config)
raise NotImplementedError
if self._message_box.is_empty() and self._config.session.user:
try:
await self(functions.updates.get_state())
except Exception:
pass
def is_connected(self: Client) -> None: async def disconnect(self: Client) -> None:
self if not self._sender:
raise NotImplementedError return
await self._sender.disconnect()
self._sender = None
def disconnect(self: Client) -> None: async def invoke_request(
self self: Client,
raise NotImplementedError sender: Sender,
lock: asyncio.Lock,
request: Request[Return],
) -> Return:
slept_flood = False
sleep_thresh = self._config.flood_sleep_threshold or 0
rx = sender.enqueue(request)
while True:
while not rx.done():
await step_sender(self, sender, lock)
try:
response = rx.result()
break
except RpcError as e:
if (
e.code == 420
and e.value is not None
and not slept_flood
and e.value < sleep_thresh
):
await asyncio.sleep(e.value)
slept_flood = True
rx = sender.enqueue(request)
continue
else:
raise
return request.deserialize_response(response)
def set_proxy(self: Client) -> None: async def step(self: Client) -> None:
self if self._sender:
raise NotImplementedError await step_sender(self, self._sender, self._sender_lock)
async def step_sender(self: Client, sender: Sender, lock: asyncio.Lock) -> None:
if lock.locked():
async with lock:
pass
else:
async with lock:
updates = await sender.step()
# self._process_socket_updates(updates)
async def run_until_disconnected(self: Client) -> None:
while self.connected:
await self.step()
def connected(self: Client) -> bool:
return self._sender is not None

View File

@ -0,0 +1,37 @@
import asyncio
import os
import random
from telethon._impl.client.client.client import Client
from telethon._impl.client.client.net import Config
from telethon._impl.session.message_box.defs import Session
from telethon._impl.tl.mtproto import functions, types
def test_ping_pong() -> None:
async def func() -> None:
api_id = os.getenv("TG_ID")
api_hash = os.getenv("TG_HASH")
assert api_id and api_id.isdigit()
assert api_hash
client = Client(
Config(
session=Session(
dcs=[],
user=None,
state=None,
),
api_id=int(api_id),
api_hash=api_hash,
)
)
assert not client.connected
await client.connect()
assert client.connected
ping_id = random.randrange(-(2**63), 2**63)
pong = await client(functions.ping(ping_id=ping_id))
assert isinstance(pong, types.Pong)
assert pong.ping_id == ping_id
asyncio.run(func())