Improve handling of remote config

This commit is contained in:
Lonami Exo 2023-10-13 22:55:42 +02:00
parent b4f9d3d720
commit 6ed279e773
5 changed files with 161 additions and 103 deletions

View File

@ -6,10 +6,11 @@ from typing import TYPE_CHECKING, Optional, Union
from ...crypto import two_factor_auth
from ...mtproto import RpcError
from ...session import DataCenter
from ...session import User as SessionUser
from ...tl import abcs, functions, types
from ..types import LoginToken, PasswordToken, User
from .net import connect_sender, datacenter_for_id
from .net import connect_sender
if TYPE_CHECKING:
from .client import Client
@ -50,7 +51,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None
sender, client._session.dcs = await connect_sender(
client._config, datacenter_for_id(client, dc_id), client._logger
client._config, client._session.dcs, DataCenter(id=dc_id), client._logger
)
async with client._sender_lock:
client._sender = sender

View File

@ -6,16 +6,16 @@ import logging
import platform
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar
from ....version import __version__
from ...mtproto import Full, RpcError
from ...mtproto import BadStatus, Full, RpcError
from ...mtsender import Sender
from ...mtsender import connect as connect_without_auth
from ...mtsender import connect_with_auth
from ...session import DataCenter
from ...session import User as SessionUser
from ...tl import LAYER, Request, functions, types
from ...tl import LAYER, Request, abcs, functions, types
from ..errors import adapt_rpc
from .updates import dispatcher, process_socket_updates
@ -56,33 +56,51 @@ class Config:
update_queue_limit: Optional[int] = None
KNOWN_DC = [
DataCenter(id=1, addr="149.154.175.53:443", auth=None),
DataCenter(id=2, addr="149.154.167.51:443", auth=None),
DataCenter(id=3, addr="149.154.175.100:443", auth=None),
DataCenter(id=4, addr="149.154.167.92:443", auth=None),
DataCenter(id=5, addr="91.108.56.190:443", auth=None),
KNOWN_DCS = [
DataCenter(id=1, ipv4_addr="149.154.175.53:443", ipv6_addr=None, auth=None),
DataCenter(id=2, ipv4_addr="149.154.167.51:443", ipv6_addr=None, auth=None),
DataCenter(id=3, ipv4_addr="149.154.175.100:443", ipv6_addr=None, auth=None),
DataCenter(id=4, ipv4_addr="149.154.167.92:443", ipv6_addr=None, auth=None),
DataCenter(id=5, ipv4_addr="91.108.56.190:443", ipv6_addr=None, auth=None),
]
DEFAULT_DC = 2
def as_concrete_dc_option(opt: abcs.DcOption) -> types.DcOption:
assert isinstance(opt, types.DcOption)
return opt
async def connect_sender(
config: Config,
known_dcs: List[DataCenter],
dc: DataCenter,
base_logger: logging.Logger,
force_auth_gen: bool = False,
) -> Tuple[Sender, List[DataCenter]]:
transport = Full()
if dc.auth:
sender = await connect_with_auth(
transport, dc.id, dc.addr, dc.auth, base_logger
# Only the ID of the input DC may be known.
# Find the corresponding address and authentication key if needed.
addr = dc.ipv4_addr or next(
d.ipv4_addr
for d in itertools.chain(known_dcs, KNOWN_DCS)
if d.id == dc.id and d.ipv4_addr
)
auth = (
None
if force_auth_gen
else dc.auth
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
)
else:
sender = await connect_without_auth(transport, dc.id, dc.addr, base_logger)
# TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it)
remote_config = await sender.invoke(
transport = Full()
if auth:
sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger)
else:
sender = await connect_without_auth(transport, dc.id, addr, base_logger)
try:
remote_config_data = await sender.invoke(
functions.invoke_with_layer(
layer=LAYER,
query=functions.init_connection(
@ -99,29 +117,53 @@ async def connect_sender(
),
)
)
latest_dcs = []
append_current = True
for opt in types.Config.from_bytes(remote_config).dc_options:
assert isinstance(opt, types.DcOption)
latest_dcs.append(
DataCenter(
id=opt.id,
addr=opt.ip_address,
auth=sender.auth_key if sender.dc_id == opt.id else None,
except BadStatus as e:
if e.status == 404 and auth:
dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
)
base_logger.warning(
"datacenter could not find stored auth; will retry generating a new one: %s",
dc,
)
if sender.dc_id == opt.id:
append_current = False
if append_current:
# Current config has no DC with current ID.
# Append it to preserve the authorization key.
latest_dcs.append(
DataCenter(id=sender.dc_id, addr=sender.addr, auth=sender.auth_key)
return await connect_sender(
config, known_dcs, dc, base_logger, force_auth_gen=True
)
else:
raise
return sender, latest_dcs
remote_config = types.Config.from_bytes(remote_config_data)
# Filter the primary data-centers to persist, static first.
dc_options = [
opt
for opt in map(as_concrete_dc_option, remote_config.dc_options)
if not (opt.media_only or opt.tcpo_only or opt.cdn)
]
dc_options.sort(key=lambda opt: opt.static, reverse=True)
latest_dcs: Dict[int, DataCenter] = {}
for opt in dc_options:
dc = latest_dcs.setdefault(opt.id, DataCenter(id=opt.id))
if opt.ipv6:
if not dc.ipv6_addr:
dc.ipv6_addr = f"{opt.ip_address}:{opt.port}"
else:
if not dc.ipv4_addr:
dc.ipv4_addr = f"{opt.ip_address}:{opt.port}"
# Restore only missing information.
for dc in itertools.chain(
[DataCenter(id=sender.dc_id, ipv4_addr=sender.addr, auth=sender.auth_key)],
known_dcs,
):
saved_dc = latest_dcs.setdefault(sender.dc_id, DataCenter(id=dc.id))
saved_dc.ipv4_addr = saved_dc.ipv4_addr or dc.ipv4_addr
saved_dc.ipv6_addr = saved_dc.ipv6_addr or dc.ipv6_addr
saved_dc.auth = saved_dc.auth or dc.auth
session_dcs = [dc for _, dc in sorted(latest_dcs.items(), key=lambda t: t[0])]
return sender, session_dcs
async def connect(self: Client) -> None:
@ -131,28 +173,11 @@ async def connect(self: Client) -> None:
if session := await self._storage.load():
self._session = session
if dc := self._config.datacenter:
# Datacenter override, reusing the session's auth-key unless already present.
datacenter = (
dc
if dc.auth
else DataCenter(
id=dc.id,
addr=dc.addr,
auth=next(
(d.auth for d in self._session.dcs if d.id == dc.id and d.auth),
None,
),
datacenter = self._config.datacenter or DataCenter(
id=self._session.user.dc if self._session.user else DEFAULT_DC
)
)
else:
# Reuse the session's datacenter, falling back to defaults if not found.
datacenter = datacenter_for_id(
self, self._session.user.dc if self._session.user else DEFAULT_DC
)
self._sender, self._session.dcs = await connect_sender(
self._config, datacenter, self._logger
self._config, self._session.dcs, datacenter, self._logger
)
if self._message_box.is_empty() and self._session.user:
@ -177,17 +202,6 @@ async def connect(self: Client) -> None:
self._dispatcher = asyncio.create_task(dispatcher(self))
def datacenter_for_id(client: Client, dc_id: int) -> DataCenter:
try:
return next(
dc
for dc in itertools.chain(client._session.dcs, KNOWN_DC)
if dc.id == dc_id
)
except StopIteration:
raise ValueError(f"no datacenter found for id: {dc_id}") from None
async def disconnect(self: Client) -> None:
if not self._sender:
return

View File

@ -1,25 +1,48 @@
from typing import List, Optional
from ..tl.core.serializable import obj_repr
class DataCenter:
"""
Data-center information.
:param id: See below.
:param addr: See below.
:param ipv4_addr: See below.
:param ipv6_addr: See below.
:param auth: See below.
"""
__slots__ = ("id", "addr", "auth")
__slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth")
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None:
def __init__(
self,
*,
id: int,
ipv4_addr: Optional[str] = None,
ipv6_addr: Optional[str] = None,
auth: Optional[bytes] = None,
) -> None:
self.id = id
"The DC identifier."
self.addr = addr
"The server address of the DC, in ``'ip:port'`` format."
self.ipv4_addr = ipv4_addr
"The IPv4 socket server address of the DC, in ``'ip:port'`` format."
self.ipv6_addr = ipv6_addr
"The IPv6 socket server address of the DC, in ``'ip:port'`` format."
self.auth = auth
"Authentication key to encrypt communication with."
def __repr__(self) -> str:
# Censor auth
return obj_repr(
DataCenter(
id=self.id,
ipv4_addr=self.ipv4_addr,
ipv6_addr=self.ipv6_addr,
auth=b"..." if self.auth else None,
)
)
class User:
"""
@ -43,6 +66,9 @@ class User:
self.username = username
"User's primary username."
__repr__ = obj_repr
__str__ = __repr__
class ChannelState:
"""
@ -60,6 +86,9 @@ class ChannelState:
self.pts = pts
"The channel's partial sequence number."
__repr__ = obj_repr
__str__ = __repr__
class UpdateState:
"""
@ -100,6 +129,9 @@ class UpdateState:
self.channels = channels
"Update state for channels."
__repr__ = obj_repr
__str__ = __repr__
class Session:
"""
@ -143,3 +175,6 @@ class Session:
"Information about the logged-in user."
self.state = state
"Update state."
__repr__ = obj_repr
__str__ = __repr__

View File

@ -79,7 +79,7 @@ class SqliteSession(Storage):
return Session(
dcs=[
DataCenter(id=id, addr=f"{ip}:{port}", auth=auth)
DataCenter(id=id, ipv4_addr=f"{ip}:{port}", ipv6_addr=None, auth=auth)
for (id, ip, port, auth) in sessions
],
user=None,
@ -105,8 +105,8 @@ class SqliteSession(Storage):
return Session(
dcs=[
DataCenter(id=id, addr=addr, auth=auth)
for (id, addr, auth) in datacenter
DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth)
for (id, ipv4_addr, ipv6_addr, auth) in datacenter
],
user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
if user
@ -129,8 +129,8 @@ class SqliteSession(Storage):
c.execute("delete from state")
c.execute("delete from channelstate")
c.executemany(
"insert into datacenter values (?, ?, ?)",
[(dc.id, dc.addr, dc.auth) for dc in session.dcs],
"insert into datacenter values (?, ?, ?, ?)",
[(dc.id, dc.ipv4_addr, dc.ipv6_addr, dc.auth) for dc in session.dcs],
)
if user := session.user:
c.execute(
@ -188,7 +188,8 @@ class SqliteSession(Storage):
);
create table datacenter(
id integer primary key,
addr text not null,
ipv4_addr text,
ipv6_addr text,
auth blob
);
create table user(

View File

@ -1,10 +1,20 @@
import abc
import struct
from typing import Self, Tuple
from typing import Protocol, Self, Tuple
from .reader import Reader
class HasSlots(Protocol):
__slots__: Tuple[str, ...]
def obj_repr(obj: HasSlots) -> str:
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__)
params = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{obj.__class__.__name__}({params})"
class Serializable(abc.ABC):
__slots__: Tuple[str, ...] = ()
@ -34,10 +44,7 @@ class Serializable(abc.ABC):
self._write_boxed_to(buffer)
return bytes(buffer)
def __repr__(self) -> str:
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
attrs = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{self.__class__.__name__}({attrs})"
__repr__ = obj_repr
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):