mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-01-24 16:24:15 +03:00
Improve handling of remote config
This commit is contained in:
parent
b4f9d3d720
commit
6ed279e773
|
@ -6,10 +6,11 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||
|
||||
from ...crypto import two_factor_auth
|
||||
from ...mtproto import RpcError
|
||||
from ...session import DataCenter
|
||||
from ...session import User as SessionUser
|
||||
from ...tl import abcs, functions, types
|
||||
from ..types import LoginToken, PasswordToken, User
|
||||
from .net import connect_sender, datacenter_for_id
|
||||
from .net import connect_sender
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
|
@ -50,7 +51,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
|
|||
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
||||
assert dc_id is not None
|
||||
sender, client._session.dcs = await connect_sender(
|
||||
client._config, datacenter_for_id(client, dc_id), client._logger
|
||||
client._config, client._session.dcs, DataCenter(id=dc_id), client._logger
|
||||
)
|
||||
async with client._sender_lock:
|
||||
client._sender = sender
|
||||
|
|
|
@ -6,16 +6,16 @@ import logging
|
|||
import platform
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar
|
||||
|
||||
from ....version import __version__
|
||||
from ...mtproto import Full, RpcError
|
||||
from ...mtproto import BadStatus, Full, RpcError
|
||||
from ...mtsender import Sender
|
||||
from ...mtsender import connect as connect_without_auth
|
||||
from ...mtsender import connect_with_auth
|
||||
from ...session import DataCenter
|
||||
from ...session import User as SessionUser
|
||||
from ...tl import LAYER, Request, functions, types
|
||||
from ...tl import LAYER, Request, abcs, functions, types
|
||||
from ..errors import adapt_rpc
|
||||
from .updates import dispatcher, process_socket_updates
|
||||
|
||||
|
@ -56,72 +56,114 @@ class Config:
|
|||
update_queue_limit: Optional[int] = None
|
||||
|
||||
|
||||
KNOWN_DC = [
|
||||
DataCenter(id=1, addr="149.154.175.53:443", auth=None),
|
||||
DataCenter(id=2, addr="149.154.167.51:443", auth=None),
|
||||
DataCenter(id=3, addr="149.154.175.100:443", auth=None),
|
||||
DataCenter(id=4, addr="149.154.167.92:443", auth=None),
|
||||
DataCenter(id=5, addr="91.108.56.190:443", auth=None),
|
||||
KNOWN_DCS = [
|
||||
DataCenter(id=1, ipv4_addr="149.154.175.53:443", ipv6_addr=None, auth=None),
|
||||
DataCenter(id=2, ipv4_addr="149.154.167.51:443", ipv6_addr=None, auth=None),
|
||||
DataCenter(id=3, ipv4_addr="149.154.175.100:443", ipv6_addr=None, auth=None),
|
||||
DataCenter(id=4, ipv4_addr="149.154.167.92:443", ipv6_addr=None, auth=None),
|
||||
DataCenter(id=5, ipv4_addr="91.108.56.190:443", ipv6_addr=None, auth=None),
|
||||
]
|
||||
|
||||
DEFAULT_DC = 2
|
||||
|
||||
|
||||
def as_concrete_dc_option(opt: abcs.DcOption) -> types.DcOption:
|
||||
assert isinstance(opt, types.DcOption)
|
||||
return opt
|
||||
|
||||
|
||||
async def connect_sender(
|
||||
config: Config,
|
||||
known_dcs: List[DataCenter],
|
||||
dc: DataCenter,
|
||||
base_logger: logging.Logger,
|
||||
force_auth_gen: bool = False,
|
||||
) -> Tuple[Sender, List[DataCenter]]:
|
||||
transport = Full()
|
||||
|
||||
if dc.auth:
|
||||
sender = await connect_with_auth(
|
||||
transport, dc.id, dc.addr, dc.auth, base_logger
|
||||
)
|
||||
else:
|
||||
sender = await connect_without_auth(transport, dc.id, dc.addr, base_logger)
|
||||
|
||||
# TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it)
|
||||
remote_config = await sender.invoke(
|
||||
functions.invoke_with_layer(
|
||||
layer=LAYER,
|
||||
query=functions.init_connection(
|
||||
api_id=config.api_id,
|
||||
device_model=config.device_model,
|
||||
system_version=config.system_version,
|
||||
app_version=config.app_version,
|
||||
system_lang_code=config.system_lang_code,
|
||||
lang_pack="",
|
||||
lang_code=config.lang_code,
|
||||
proxy=None,
|
||||
params=None,
|
||||
query=functions.help.get_config(),
|
||||
),
|
||||
)
|
||||
# Only the ID of the input DC may be known.
|
||||
# Find the corresponding address and authentication key if needed.
|
||||
addr = dc.ipv4_addr or next(
|
||||
d.ipv4_addr
|
||||
for d in itertools.chain(known_dcs, KNOWN_DCS)
|
||||
if d.id == dc.id and d.ipv4_addr
|
||||
)
|
||||
auth = (
|
||||
None
|
||||
if force_auth_gen
|
||||
else dc.auth
|
||||
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
|
||||
)
|
||||
|
||||
latest_dcs = []
|
||||
append_current = True
|
||||
for opt in types.Config.from_bytes(remote_config).dc_options:
|
||||
assert isinstance(opt, types.DcOption)
|
||||
latest_dcs.append(
|
||||
DataCenter(
|
||||
id=opt.id,
|
||||
addr=opt.ip_address,
|
||||
auth=sender.auth_key if sender.dc_id == opt.id else None,
|
||||
transport = Full()
|
||||
if auth:
|
||||
sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger)
|
||||
else:
|
||||
sender = await connect_without_auth(transport, dc.id, addr, base_logger)
|
||||
|
||||
try:
|
||||
remote_config_data = await sender.invoke(
|
||||
functions.invoke_with_layer(
|
||||
layer=LAYER,
|
||||
query=functions.init_connection(
|
||||
api_id=config.api_id,
|
||||
device_model=config.device_model,
|
||||
system_version=config.system_version,
|
||||
app_version=config.app_version,
|
||||
system_lang_code=config.system_lang_code,
|
||||
lang_pack="",
|
||||
lang_code=config.lang_code,
|
||||
proxy=None,
|
||||
params=None,
|
||||
query=functions.help.get_config(),
|
||||
),
|
||||
)
|
||||
)
|
||||
if sender.dc_id == opt.id:
|
||||
append_current = False
|
||||
except BadStatus as e:
|
||||
if e.status == 404 and auth:
|
||||
dc = DataCenter(
|
||||
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
|
||||
)
|
||||
base_logger.warning(
|
||||
"datacenter could not find stored auth; will retry generating a new one: %s",
|
||||
dc,
|
||||
)
|
||||
return await connect_sender(
|
||||
config, known_dcs, dc, base_logger, force_auth_gen=True
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if append_current:
|
||||
# Current config has no DC with current ID.
|
||||
# Append it to preserve the authorization key.
|
||||
latest_dcs.append(
|
||||
DataCenter(id=sender.dc_id, addr=sender.addr, auth=sender.auth_key)
|
||||
)
|
||||
remote_config = types.Config.from_bytes(remote_config_data)
|
||||
|
||||
return sender, latest_dcs
|
||||
# Filter the primary data-centers to persist, static first.
|
||||
dc_options = [
|
||||
opt
|
||||
for opt in map(as_concrete_dc_option, remote_config.dc_options)
|
||||
if not (opt.media_only or opt.tcpo_only or opt.cdn)
|
||||
]
|
||||
dc_options.sort(key=lambda opt: opt.static, reverse=True)
|
||||
|
||||
latest_dcs: Dict[int, DataCenter] = {}
|
||||
for opt in dc_options:
|
||||
dc = latest_dcs.setdefault(opt.id, DataCenter(id=opt.id))
|
||||
if opt.ipv6:
|
||||
if not dc.ipv6_addr:
|
||||
dc.ipv6_addr = f"{opt.ip_address}:{opt.port}"
|
||||
else:
|
||||
if not dc.ipv4_addr:
|
||||
dc.ipv4_addr = f"{opt.ip_address}:{opt.port}"
|
||||
|
||||
# Restore only missing information.
|
||||
for dc in itertools.chain(
|
||||
[DataCenter(id=sender.dc_id, ipv4_addr=sender.addr, auth=sender.auth_key)],
|
||||
known_dcs,
|
||||
):
|
||||
saved_dc = latest_dcs.setdefault(sender.dc_id, DataCenter(id=dc.id))
|
||||
saved_dc.ipv4_addr = saved_dc.ipv4_addr or dc.ipv4_addr
|
||||
saved_dc.ipv6_addr = saved_dc.ipv6_addr or dc.ipv6_addr
|
||||
saved_dc.auth = saved_dc.auth or dc.auth
|
||||
|
||||
session_dcs = [dc for _, dc in sorted(latest_dcs.items(), key=lambda t: t[0])]
|
||||
return sender, session_dcs
|
||||
|
||||
|
||||
async def connect(self: Client) -> None:
|
||||
|
@ -131,28 +173,11 @@ async def connect(self: Client) -> None:
|
|||
if session := await self._storage.load():
|
||||
self._session = session
|
||||
|
||||
if dc := self._config.datacenter:
|
||||
# Datacenter override, reusing the session's auth-key unless already present.
|
||||
datacenter = (
|
||||
dc
|
||||
if dc.auth
|
||||
else DataCenter(
|
||||
id=dc.id,
|
||||
addr=dc.addr,
|
||||
auth=next(
|
||||
(d.auth for d in self._session.dcs if d.id == dc.id and d.auth),
|
||||
None,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Reuse the session's datacenter, falling back to defaults if not found.
|
||||
datacenter = datacenter_for_id(
|
||||
self, self._session.user.dc if self._session.user else DEFAULT_DC
|
||||
)
|
||||
|
||||
datacenter = self._config.datacenter or DataCenter(
|
||||
id=self._session.user.dc if self._session.user else DEFAULT_DC
|
||||
)
|
||||
self._sender, self._session.dcs = await connect_sender(
|
||||
self._config, datacenter, self._logger
|
||||
self._config, self._session.dcs, datacenter, self._logger
|
||||
)
|
||||
|
||||
if self._message_box.is_empty() and self._session.user:
|
||||
|
@ -177,17 +202,6 @@ async def connect(self: Client) -> None:
|
|||
self._dispatcher = asyncio.create_task(dispatcher(self))
|
||||
|
||||
|
||||
def datacenter_for_id(client: Client, dc_id: int) -> DataCenter:
|
||||
try:
|
||||
return next(
|
||||
dc
|
||||
for dc in itertools.chain(client._session.dcs, KNOWN_DC)
|
||||
if dc.id == dc_id
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError(f"no datacenter found for id: {dc_id}") from None
|
||||
|
||||
|
||||
async def disconnect(self: Client) -> None:
|
||||
if not self._sender:
|
||||
return
|
||||
|
|
|
@ -1,25 +1,48 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from ..tl.core.serializable import obj_repr
|
||||
|
||||
|
||||
class DataCenter:
|
||||
"""
|
||||
Data-center information.
|
||||
|
||||
:param id: See below.
|
||||
:param addr: See below.
|
||||
:param ipv4_addr: See below.
|
||||
:param ipv6_addr: See below.
|
||||
:param auth: See below.
|
||||
"""
|
||||
|
||||
__slots__ = ("id", "addr", "auth")
|
||||
__slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth")
|
||||
|
||||
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: int,
|
||||
ipv4_addr: Optional[str] = None,
|
||||
ipv6_addr: Optional[str] = None,
|
||||
auth: Optional[bytes] = None,
|
||||
) -> None:
|
||||
self.id = id
|
||||
"The DC identifier."
|
||||
self.addr = addr
|
||||
"The server address of the DC, in ``'ip:port'`` format."
|
||||
self.ipv4_addr = ipv4_addr
|
||||
"The IPv4 socket server address of the DC, in ``'ip:port'`` format."
|
||||
self.ipv6_addr = ipv6_addr
|
||||
"The IPv6 socket server address of the DC, in ``'ip:port'`` format."
|
||||
self.auth = auth
|
||||
"Authentication key to encrypt communication with."
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Censor auth
|
||||
return obj_repr(
|
||||
DataCenter(
|
||||
id=self.id,
|
||||
ipv4_addr=self.ipv4_addr,
|
||||
ipv6_addr=self.ipv6_addr,
|
||||
auth=b"..." if self.auth else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class User:
|
||||
"""
|
||||
|
@ -43,6 +66,9 @@ class User:
|
|||
self.username = username
|
||||
"User's primary username."
|
||||
|
||||
__repr__ = obj_repr
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
class ChannelState:
|
||||
"""
|
||||
|
@ -60,6 +86,9 @@ class ChannelState:
|
|||
self.pts = pts
|
||||
"The channel's partial sequence number."
|
||||
|
||||
__repr__ = obj_repr
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
class UpdateState:
|
||||
"""
|
||||
|
@ -100,6 +129,9 @@ class UpdateState:
|
|||
self.channels = channels
|
||||
"Update state for channels."
|
||||
|
||||
__repr__ = obj_repr
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
class Session:
|
||||
"""
|
||||
|
@ -143,3 +175,6 @@ class Session:
|
|||
"Information about the logged-in user."
|
||||
self.state = state
|
||||
"Update state."
|
||||
|
||||
__repr__ = obj_repr
|
||||
__str__ = __repr__
|
||||
|
|
|
@ -79,7 +79,7 @@ class SqliteSession(Storage):
|
|||
|
||||
return Session(
|
||||
dcs=[
|
||||
DataCenter(id=id, addr=f"{ip}:{port}", auth=auth)
|
||||
DataCenter(id=id, ipv4_addr=f"{ip}:{port}", ipv6_addr=None, auth=auth)
|
||||
for (id, ip, port, auth) in sessions
|
||||
],
|
||||
user=None,
|
||||
|
@ -105,8 +105,8 @@ class SqliteSession(Storage):
|
|||
|
||||
return Session(
|
||||
dcs=[
|
||||
DataCenter(id=id, addr=addr, auth=auth)
|
||||
for (id, addr, auth) in datacenter
|
||||
DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth)
|
||||
for (id, ipv4_addr, ipv6_addr, auth) in datacenter
|
||||
],
|
||||
user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
|
||||
if user
|
||||
|
@ -129,8 +129,8 @@ class SqliteSession(Storage):
|
|||
c.execute("delete from state")
|
||||
c.execute("delete from channelstate")
|
||||
c.executemany(
|
||||
"insert into datacenter values (?, ?, ?)",
|
||||
[(dc.id, dc.addr, dc.auth) for dc in session.dcs],
|
||||
"insert into datacenter values (?, ?, ?, ?)",
|
||||
[(dc.id, dc.ipv4_addr, dc.ipv6_addr, dc.auth) for dc in session.dcs],
|
||||
)
|
||||
if user := session.user:
|
||||
c.execute(
|
||||
|
@ -188,7 +188,8 @@ class SqliteSession(Storage):
|
|||
);
|
||||
create table datacenter(
|
||||
id integer primary key,
|
||||
addr text not null,
|
||||
ipv4_addr text,
|
||||
ipv6_addr text,
|
||||
auth blob
|
||||
);
|
||||
create table user(
|
||||
|
|
|
@ -1,10 +1,20 @@
|
|||
import abc
|
||||
import struct
|
||||
from typing import Self, Tuple
|
||||
from typing import Protocol, Self, Tuple
|
||||
|
||||
from .reader import Reader
|
||||
|
||||
|
||||
class HasSlots(Protocol):
|
||||
__slots__: Tuple[str, ...]
|
||||
|
||||
|
||||
def obj_repr(obj: HasSlots) -> str:
|
||||
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__)
|
||||
params = ", ".join(f"{name}={field!r}" for name, field in fields)
|
||||
return f"{obj.__class__.__name__}({params})"
|
||||
|
||||
|
||||
class Serializable(abc.ABC):
|
||||
__slots__: Tuple[str, ...] = ()
|
||||
|
||||
|
@ -34,10 +44,7 @@ class Serializable(abc.ABC):
|
|||
self._write_boxed_to(buffer)
|
||||
return bytes(buffer)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
|
||||
attrs = ", ".join(f"{name}={field!r}" for name, field in fields)
|
||||
return f"{self.__class__.__name__}({attrs})"
|
||||
__repr__ = obj_repr
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
|
|
Loading…
Reference in New Issue
Block a user