mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-01-24 08:14:14 +03:00
Implement client connection
This commit is contained in:
parent
87ffdca4c2
commit
294f7dedd5
|
@ -1,6 +1,13 @@
|
|||
import asyncio
|
||||
from collections import deque
|
||||
from types import TracebackType
|
||||
from typing import Optional, Type
|
||||
from typing import Deque, Optional, Self, Type, TypeVar
|
||||
|
||||
from ...mtsender.sender import Sender
|
||||
from ...session.chat.hash_cache import ChatHashCache
|
||||
from ...session.message_box.messagebox import MessageBox
|
||||
from ...tl import abcs
|
||||
from ...tl.core.request import Request
|
||||
from .account import edit_2fa, end_takeout, takeout
|
||||
from .auth import log_out, qr_login, send_code_request, sign_in, sign_up, start
|
||||
from .bots import inline_query
|
||||
|
@ -29,13 +36,14 @@ from .messages import (
|
|||
unpin_message,
|
||||
)
|
||||
from .net import (
|
||||
DEFAULT_DC,
|
||||
Config,
|
||||
connect,
|
||||
connected,
|
||||
disconnect,
|
||||
disconnected,
|
||||
flood_sleep_threshold,
|
||||
is_connected,
|
||||
loop,
|
||||
set_proxy,
|
||||
invoke_request,
|
||||
run_until_disconnected,
|
||||
step,
|
||||
)
|
||||
from .updates import (
|
||||
add_event_handler,
|
||||
|
@ -43,7 +51,6 @@ from .updates import (
|
|||
list_event_handlers,
|
||||
on,
|
||||
remove_event_handler,
|
||||
run_until_disconnected,
|
||||
set_receive_updates,
|
||||
)
|
||||
from .uploads import send_file, upload_file
|
||||
|
@ -56,8 +63,26 @@ from .users import (
|
|||
is_user_authorized,
|
||||
)
|
||||
|
||||
Return = TypeVar("Return")
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self._sender: Optional[Sender] = None
|
||||
self._sender_lock = asyncio.Lock()
|
||||
self._dc_id = DEFAULT_DC
|
||||
self._config = config
|
||||
self._message_box = MessageBox()
|
||||
self._chat_hashes = ChatHashCache(None)
|
||||
self._last_update_limit_warn = None
|
||||
self._updates: Deque[abcs.Update] = deque(maxlen=config.update_queue_limit)
|
||||
self._downloader_map = object()
|
||||
|
||||
if self_user := config.session.user:
|
||||
self._dc_id = self_user.dc
|
||||
if config.catch_up and config.session.state:
|
||||
self._message_box.load(config.session.state)
|
||||
|
||||
def takeout(self) -> None:
|
||||
takeout(self)
|
||||
|
||||
|
@ -169,9 +194,6 @@ class Client:
|
|||
async def set_receive_updates(self) -> None:
|
||||
await set_receive_updates(self)
|
||||
|
||||
def run_until_disconnected(self) -> None:
|
||||
run_until_disconnected(self)
|
||||
|
||||
def on(self) -> None:
|
||||
on(self)
|
||||
|
||||
|
@ -211,29 +233,31 @@ class Client:
|
|||
async def get_peer_id(self) -> None:
|
||||
await get_peer_id(self)
|
||||
|
||||
def loop(self) -> None:
|
||||
loop(self)
|
||||
|
||||
def disconnected(self) -> None:
|
||||
disconnected(self)
|
||||
|
||||
def flood_sleep_threshold(self) -> None:
|
||||
flood_sleep_threshold(self)
|
||||
|
||||
async def connect(self) -> None:
|
||||
await connect(self)
|
||||
|
||||
def is_connected(self) -> None:
|
||||
is_connected(self)
|
||||
async def disconnect(self) -> None:
|
||||
await disconnect(self)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
disconnect(self)
|
||||
async def __call__(self, request: Request[Return]) -> Return:
|
||||
if not self._sender:
|
||||
raise ConnectionError("not connected")
|
||||
|
||||
def set_proxy(self) -> None:
|
||||
set_proxy(self)
|
||||
return await invoke_request(self, self._sender, self._sender_lock, request)
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
raise NotImplementedError
|
||||
async def step(self) -> None:
|
||||
await step(self)
|
||||
|
||||
async def run_until_disconnected(self) -> None:
|
||||
await run_until_disconnected(self)
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return connected(self)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
|
@ -242,4 +266,4 @@ class Client:
|
|||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
exc_type, exc, tb
|
||||
raise NotImplementedError
|
||||
await self.disconnect()
|
||||
|
|
|
@ -1,41 +1,192 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
import asyncio
|
||||
import platform
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
from ....version import __version__
|
||||
from ...mtproto.mtp.types import RpcError
|
||||
from ...mtproto.transport.full import Full
|
||||
from ...mtsender.sender import Sender
|
||||
from ...mtsender.sender import connect as connect_without_auth
|
||||
from ...mtsender.sender import connect_with_auth
|
||||
from ...session.message_box.defs import DataCenter, Session
|
||||
from ...tl import LAYER, functions
|
||||
from ...tl.core.request import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
||||
|
||||
def loop(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
Return = TypeVar("Return")
|
||||
|
||||
|
||||
def disconnected(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
def default_device_model() -> str:
|
||||
system = platform.uname()
|
||||
if system.machine in ("x86_64", "AMD64"):
|
||||
return "PC 64bit"
|
||||
elif system.machine in ("i386", "i686", "x86"):
|
||||
return "PC 32bit"
|
||||
else:
|
||||
return system.machine or "Unknown"
|
||||
|
||||
|
||||
def flood_sleep_threshold(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
def default_system_version() -> str:
|
||||
system = platform.uname()
|
||||
return re.sub(r"-.+", "", system.release) or "1.0"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
session: Session
|
||||
api_id: int
|
||||
api_hash: str
|
||||
device_model: str = field(default_factory=default_device_model)
|
||||
system_version: str = field(default_factory=default_system_version)
|
||||
app_version: str = __version__
|
||||
system_lang_code: str = "en"
|
||||
lang_code: str = "en"
|
||||
catch_up: bool = False
|
||||
server_addr: Optional[str] = None
|
||||
flood_sleep_threshold: Optional[int] = 60
|
||||
update_queue_limit: Optional[int] = None
|
||||
|
||||
|
||||
# dc_id to IPv4 and port pair
|
||||
DC_ADDRESSES = [
|
||||
"0.0.0.0:0",
|
||||
"149.154.175.53:443",
|
||||
"149.154.167.51:443",
|
||||
"149.154.175.100:443",
|
||||
"149.154.167.92:443",
|
||||
"91.108.56.190:443",
|
||||
]
|
||||
|
||||
DEFAULT_DC = 2
|
||||
|
||||
|
||||
async def connect_sender(dc_id: int, config: Config) -> Sender:
|
||||
transport = Full()
|
||||
|
||||
if config.server_addr:
|
||||
addr = config.server_addr
|
||||
else:
|
||||
addr = DC_ADDRESSES[dc_id]
|
||||
|
||||
auth_key: Optional[bytes] = None
|
||||
for dc in config.session.dcs:
|
||||
if dc.id == dc_id:
|
||||
if dc.auth:
|
||||
auth_key = dc.auth
|
||||
break
|
||||
|
||||
if auth_key:
|
||||
sender = await connect_with_auth(transport, addr, auth_key)
|
||||
else:
|
||||
sender = await connect_without_auth(transport, addr)
|
||||
for dc in config.session.dcs:
|
||||
if dc.id == dc_id:
|
||||
dc.auth = sender.auth_key
|
||||
break
|
||||
else:
|
||||
config.session.dcs.append(
|
||||
DataCenter(id=dc_id, addr=addr, auth=sender.auth_key)
|
||||
)
|
||||
|
||||
# TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it)
|
||||
# TODO all up-to-date server addresses should be stored in the session for future initial connections
|
||||
remote_config = await sender.invoke(
|
||||
functions.invoke_with_layer(
|
||||
layer=LAYER,
|
||||
query=functions.init_connection(
|
||||
api_id=config.api_id,
|
||||
device_model=config.device_model,
|
||||
system_version=config.system_version,
|
||||
app_version=config.app_version,
|
||||
system_lang_code=config.system_lang_code,
|
||||
lang_pack="",
|
||||
lang_code=config.lang_code,
|
||||
proxy=None,
|
||||
params=None,
|
||||
query=functions.help.get_config(),
|
||||
),
|
||||
)
|
||||
)
|
||||
remote_config
|
||||
|
||||
return sender
|
||||
|
||||
|
||||
async def connect(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
self._sender = await connect_sender(self._dc_id, self._config)
|
||||
|
||||
if self._message_box.is_empty() and self._config.session.user:
|
||||
try:
|
||||
await self(functions.updates.get_state())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def is_connected(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
async def disconnect(self: Client) -> None:
|
||||
if not self._sender:
|
||||
return
|
||||
|
||||
await self._sender.disconnect()
|
||||
self._sender = None
|
||||
|
||||
|
||||
def disconnect(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
async def invoke_request(
|
||||
self: Client,
|
||||
sender: Sender,
|
||||
lock: asyncio.Lock,
|
||||
request: Request[Return],
|
||||
) -> Return:
|
||||
slept_flood = False
|
||||
sleep_thresh = self._config.flood_sleep_threshold or 0
|
||||
rx = sender.enqueue(request)
|
||||
while True:
|
||||
while not rx.done():
|
||||
await step_sender(self, sender, lock)
|
||||
try:
|
||||
response = rx.result()
|
||||
break
|
||||
except RpcError as e:
|
||||
if (
|
||||
e.code == 420
|
||||
and e.value is not None
|
||||
and not slept_flood
|
||||
and e.value < sleep_thresh
|
||||
):
|
||||
await asyncio.sleep(e.value)
|
||||
slept_flood = True
|
||||
rx = sender.enqueue(request)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
return request.deserialize_response(response)
|
||||
|
||||
|
||||
def set_proxy(self: Client) -> None:
|
||||
self
|
||||
raise NotImplementedError
|
||||
async def step(self: Client) -> None:
|
||||
if self._sender:
|
||||
await step_sender(self, self._sender, self._sender_lock)
|
||||
|
||||
|
||||
async def step_sender(self: Client, sender: Sender, lock: asyncio.Lock) -> None:
|
||||
if lock.locked():
|
||||
async with lock:
|
||||
pass
|
||||
else:
|
||||
async with lock:
|
||||
updates = await sender.step()
|
||||
# self._process_socket_updates(updates)
|
||||
|
||||
|
||||
async def run_until_disconnected(self: Client) -> None:
|
||||
while self.connected:
|
||||
await self.step()
|
||||
|
||||
|
||||
def connected(self: Client) -> bool:
|
||||
return self._sender is not None
|
||||
|
|
37
client/tests/client_test.py
Normal file
37
client/tests/client_test.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import asyncio
|
||||
import os
|
||||
import random
|
||||
|
||||
from telethon._impl.client.client.client import Client
|
||||
from telethon._impl.client.client.net import Config
|
||||
from telethon._impl.session.message_box.defs import Session
|
||||
from telethon._impl.tl.mtproto import functions, types
|
||||
|
||||
|
||||
def test_ping_pong() -> None:
|
||||
async def func() -> None:
|
||||
api_id = os.getenv("TG_ID")
|
||||
api_hash = os.getenv("TG_HASH")
|
||||
assert api_id and api_id.isdigit()
|
||||
assert api_hash
|
||||
client = Client(
|
||||
Config(
|
||||
session=Session(
|
||||
dcs=[],
|
||||
user=None,
|
||||
state=None,
|
||||
),
|
||||
api_id=int(api_id),
|
||||
api_hash=api_hash,
|
||||
)
|
||||
)
|
||||
assert not client.connected
|
||||
await client.connect()
|
||||
assert client.connected
|
||||
|
||||
ping_id = random.randrange(-(2**63), 2**63)
|
||||
pong = await client(functions.ping(ping_id=ping_id))
|
||||
assert isinstance(pong, types.Pong)
|
||||
assert pong.ping_id == ping_id
|
||||
|
||||
asyncio.run(func())
|
Loading…
Reference in New Issue
Block a user