Improve handling of remote config

This commit is contained in:
Lonami Exo 2023-10-13 22:55:42 +02:00
parent b4f9d3d720
commit 6ed279e773
5 changed files with 161 additions and 103 deletions

View File

@ -6,10 +6,11 @@ from typing import TYPE_CHECKING, Optional, Union
from ...crypto import two_factor_auth from ...crypto import two_factor_auth
from ...mtproto import RpcError from ...mtproto import RpcError
from ...session import DataCenter
from ...session import User as SessionUser from ...session import User as SessionUser
from ...tl import abcs, functions, types from ...tl import abcs, functions, types
from ..types import LoginToken, PasswordToken, User from ..types import LoginToken, PasswordToken, User
from .net import connect_sender, datacenter_for_id from .net import connect_sender
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import Client from .client import Client
@ -50,7 +51,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None assert dc_id is not None
sender, client._session.dcs = await connect_sender( sender, client._session.dcs = await connect_sender(
client._config, datacenter_for_id(client, dc_id), client._logger client._config, client._session.dcs, DataCenter(id=dc_id), client._logger
) )
async with client._sender_lock: async with client._sender_lock:
client._sender = sender client._sender = sender

View File

@ -6,16 +6,16 @@ import logging
import platform import platform
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar
from ....version import __version__ from ....version import __version__
from ...mtproto import Full, RpcError from ...mtproto import BadStatus, Full, RpcError
from ...mtsender import Sender from ...mtsender import Sender
from ...mtsender import connect as connect_without_auth from ...mtsender import connect as connect_without_auth
from ...mtsender import connect_with_auth from ...mtsender import connect_with_auth
from ...session import DataCenter from ...session import DataCenter
from ...session import User as SessionUser from ...session import User as SessionUser
from ...tl import LAYER, Request, functions, types from ...tl import LAYER, Request, abcs, functions, types
from ..errors import adapt_rpc from ..errors import adapt_rpc
from .updates import dispatcher, process_socket_updates from .updates import dispatcher, process_socket_updates
@ -56,72 +56,114 @@ class Config:
update_queue_limit: Optional[int] = None update_queue_limit: Optional[int] = None
KNOWN_DC = [ KNOWN_DCS = [
DataCenter(id=1, addr="149.154.175.53:443", auth=None), DataCenter(id=1, ipv4_addr="149.154.175.53:443", ipv6_addr=None, auth=None),
DataCenter(id=2, addr="149.154.167.51:443", auth=None), DataCenter(id=2, ipv4_addr="149.154.167.51:443", ipv6_addr=None, auth=None),
DataCenter(id=3, addr="149.154.175.100:443", auth=None), DataCenter(id=3, ipv4_addr="149.154.175.100:443", ipv6_addr=None, auth=None),
DataCenter(id=4, addr="149.154.167.92:443", auth=None), DataCenter(id=4, ipv4_addr="149.154.167.92:443", ipv6_addr=None, auth=None),
DataCenter(id=5, addr="91.108.56.190:443", auth=None), DataCenter(id=5, ipv4_addr="91.108.56.190:443", ipv6_addr=None, auth=None),
] ]
DEFAULT_DC = 2 DEFAULT_DC = 2
def as_concrete_dc_option(opt: abcs.DcOption) -> types.DcOption:
assert isinstance(opt, types.DcOption)
return opt
async def connect_sender( async def connect_sender(
config: Config, config: Config,
known_dcs: List[DataCenter],
dc: DataCenter, dc: DataCenter,
base_logger: logging.Logger, base_logger: logging.Logger,
force_auth_gen: bool = False,
) -> Tuple[Sender, List[DataCenter]]: ) -> Tuple[Sender, List[DataCenter]]:
transport = Full() # Only the ID of the input DC may be known.
# Find the corresponding address and authentication key if needed.
if dc.auth: addr = dc.ipv4_addr or next(
sender = await connect_with_auth( d.ipv4_addr
transport, dc.id, dc.addr, dc.auth, base_logger for d in itertools.chain(known_dcs, KNOWN_DCS)
) if d.id == dc.id and d.ipv4_addr
else: )
sender = await connect_without_auth(transport, dc.id, dc.addr, base_logger) auth = (
None
# TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it) if force_auth_gen
remote_config = await sender.invoke( else dc.auth
functions.invoke_with_layer( or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
layer=LAYER,
query=functions.init_connection(
api_id=config.api_id,
device_model=config.device_model,
system_version=config.system_version,
app_version=config.app_version,
system_lang_code=config.system_lang_code,
lang_pack="",
lang_code=config.lang_code,
proxy=None,
params=None,
query=functions.help.get_config(),
),
)
) )
latest_dcs = [] transport = Full()
append_current = True if auth:
for opt in types.Config.from_bytes(remote_config).dc_options: sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger)
assert isinstance(opt, types.DcOption) else:
latest_dcs.append( sender = await connect_without_auth(transport, dc.id, addr, base_logger)
DataCenter(
id=opt.id, try:
addr=opt.ip_address, remote_config_data = await sender.invoke(
auth=sender.auth_key if sender.dc_id == opt.id else None, functions.invoke_with_layer(
layer=LAYER,
query=functions.init_connection(
api_id=config.api_id,
device_model=config.device_model,
system_version=config.system_version,
app_version=config.app_version,
system_lang_code=config.system_lang_code,
lang_pack="",
lang_code=config.lang_code,
proxy=None,
params=None,
query=functions.help.get_config(),
),
) )
) )
if sender.dc_id == opt.id: except BadStatus as e:
append_current = False if e.status == 404 and auth:
dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
)
base_logger.warning(
"datacenter could not find stored auth; will retry generating a new one: %s",
dc,
)
return await connect_sender(
config, known_dcs, dc, base_logger, force_auth_gen=True
)
else:
raise
if append_current: remote_config = types.Config.from_bytes(remote_config_data)
# Current config has no DC with current ID.
# Append it to preserve the authorization key.
latest_dcs.append(
DataCenter(id=sender.dc_id, addr=sender.addr, auth=sender.auth_key)
)
return sender, latest_dcs # Filter the primary data-centers to persist, static first.
dc_options = [
opt
for opt in map(as_concrete_dc_option, remote_config.dc_options)
if not (opt.media_only or opt.tcpo_only or opt.cdn)
]
dc_options.sort(key=lambda opt: opt.static, reverse=True)
latest_dcs: Dict[int, DataCenter] = {}
for opt in dc_options:
dc = latest_dcs.setdefault(opt.id, DataCenter(id=opt.id))
if opt.ipv6:
if not dc.ipv6_addr:
dc.ipv6_addr = f"{opt.ip_address}:{opt.port}"
else:
if not dc.ipv4_addr:
dc.ipv4_addr = f"{opt.ip_address}:{opt.port}"
# Restore only missing information.
for dc in itertools.chain(
[DataCenter(id=sender.dc_id, ipv4_addr=sender.addr, auth=sender.auth_key)],
known_dcs,
):
saved_dc = latest_dcs.setdefault(sender.dc_id, DataCenter(id=dc.id))
saved_dc.ipv4_addr = saved_dc.ipv4_addr or dc.ipv4_addr
saved_dc.ipv6_addr = saved_dc.ipv6_addr or dc.ipv6_addr
saved_dc.auth = saved_dc.auth or dc.auth
session_dcs = [dc for _, dc in sorted(latest_dcs.items(), key=lambda t: t[0])]
return sender, session_dcs
async def connect(self: Client) -> None: async def connect(self: Client) -> None:
@ -131,28 +173,11 @@ async def connect(self: Client) -> None:
if session := await self._storage.load(): if session := await self._storage.load():
self._session = session self._session = session
if dc := self._config.datacenter: datacenter = self._config.datacenter or DataCenter(
# Datacenter override, reusing the session's auth-key unless already present. id=self._session.user.dc if self._session.user else DEFAULT_DC
datacenter = ( )
dc
if dc.auth
else DataCenter(
id=dc.id,
addr=dc.addr,
auth=next(
(d.auth for d in self._session.dcs if d.id == dc.id and d.auth),
None,
),
)
)
else:
# Reuse the session's datacenter, falling back to defaults if not found.
datacenter = datacenter_for_id(
self, self._session.user.dc if self._session.user else DEFAULT_DC
)
self._sender, self._session.dcs = await connect_sender( self._sender, self._session.dcs = await connect_sender(
self._config, datacenter, self._logger self._config, self._session.dcs, datacenter, self._logger
) )
if self._message_box.is_empty() and self._session.user: if self._message_box.is_empty() and self._session.user:
@ -177,17 +202,6 @@ async def connect(self: Client) -> None:
self._dispatcher = asyncio.create_task(dispatcher(self)) self._dispatcher = asyncio.create_task(dispatcher(self))
def datacenter_for_id(client: Client, dc_id: int) -> DataCenter:
try:
return next(
dc
for dc in itertools.chain(client._session.dcs, KNOWN_DC)
if dc.id == dc_id
)
except StopIteration:
raise ValueError(f"no datacenter found for id: {dc_id}") from None
async def disconnect(self: Client) -> None: async def disconnect(self: Client) -> None:
if not self._sender: if not self._sender:
return return

View File

@ -1,25 +1,48 @@
from typing import List, Optional from typing import List, Optional
from ..tl.core.serializable import obj_repr
class DataCenter: class DataCenter:
""" """
Data-center information. Data-center information.
:param id: See below. :param id: See below.
:param addr: See below. :param ipv4_addr: See below.
:param ipv6_addr: See below.
:param auth: See below. :param auth: See below.
""" """
__slots__ = ("id", "addr", "auth") __slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth")
def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None: def __init__(
self,
*,
id: int,
ipv4_addr: Optional[str] = None,
ipv6_addr: Optional[str] = None,
auth: Optional[bytes] = None,
) -> None:
self.id = id self.id = id
"The DC identifier." "The DC identifier."
self.addr = addr self.ipv4_addr = ipv4_addr
"The server address of the DC, in ``'ip:port'`` format." "The IPv4 socket server address of the DC, in ``'ip:port'`` format."
self.ipv6_addr = ipv6_addr
"The IPv6 socket server address of the DC, in ``'ip:port'`` format."
self.auth = auth self.auth = auth
"Authentication key to encrypt communication with." "Authentication key to encrypt communication with."
def __repr__(self) -> str:
# Censor auth
return obj_repr(
DataCenter(
id=self.id,
ipv4_addr=self.ipv4_addr,
ipv6_addr=self.ipv6_addr,
auth=b"..." if self.auth else None,
)
)
class User: class User:
""" """
@ -43,6 +66,9 @@ class User:
self.username = username self.username = username
"User's primary username." "User's primary username."
__repr__ = obj_repr
__str__ = __repr__
class ChannelState: class ChannelState:
""" """
@ -60,6 +86,9 @@ class ChannelState:
self.pts = pts self.pts = pts
"The channel's partial sequence number." "The channel's partial sequence number."
__repr__ = obj_repr
__str__ = __repr__
class UpdateState: class UpdateState:
""" """
@ -100,6 +129,9 @@ class UpdateState:
self.channels = channels self.channels = channels
"Update state for channels." "Update state for channels."
__repr__ = obj_repr
__str__ = __repr__
class Session: class Session:
""" """
@ -143,3 +175,6 @@ class Session:
"Information about the logged-in user." "Information about the logged-in user."
self.state = state self.state = state
"Update state." "Update state."
__repr__ = obj_repr
__str__ = __repr__

View File

@ -79,7 +79,7 @@ class SqliteSession(Storage):
return Session( return Session(
dcs=[ dcs=[
DataCenter(id=id, addr=f"{ip}:{port}", auth=auth) DataCenter(id=id, ipv4_addr=f"{ip}:{port}", ipv6_addr=None, auth=auth)
for (id, ip, port, auth) in sessions for (id, ip, port, auth) in sessions
], ],
user=None, user=None,
@ -105,8 +105,8 @@ class SqliteSession(Storage):
return Session( return Session(
dcs=[ dcs=[
DataCenter(id=id, addr=addr, auth=auth) DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth)
for (id, addr, auth) in datacenter for (id, ipv4_addr, ipv6_addr, auth) in datacenter
], ],
user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3])
if user if user
@ -129,8 +129,8 @@ class SqliteSession(Storage):
c.execute("delete from state") c.execute("delete from state")
c.execute("delete from channelstate") c.execute("delete from channelstate")
c.executemany( c.executemany(
"insert into datacenter values (?, ?, ?)", "insert into datacenter values (?, ?, ?, ?)",
[(dc.id, dc.addr, dc.auth) for dc in session.dcs], [(dc.id, dc.ipv4_addr, dc.ipv6_addr, dc.auth) for dc in session.dcs],
) )
if user := session.user: if user := session.user:
c.execute( c.execute(
@ -188,7 +188,8 @@ class SqliteSession(Storage):
); );
create table datacenter( create table datacenter(
id integer primary key, id integer primary key,
addr text not null, ipv4_addr text,
ipv6_addr text,
auth blob auth blob
); );
create table user( create table user(

View File

@ -1,10 +1,20 @@
import abc import abc
import struct import struct
from typing import Self, Tuple from typing import Protocol, Self, Tuple
from .reader import Reader from .reader import Reader
class HasSlots(Protocol):
__slots__: Tuple[str, ...]
def obj_repr(obj: HasSlots) -> str:
fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__)
params = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{obj.__class__.__name__}({params})"
class Serializable(abc.ABC): class Serializable(abc.ABC):
__slots__: Tuple[str, ...] = () __slots__: Tuple[str, ...] = ()
@ -34,10 +44,7 @@ class Serializable(abc.ABC):
self._write_boxed_to(buffer) self._write_boxed_to(buffer)
return bytes(buffer) return bytes(buffer)
def __repr__(self) -> str: __repr__ = obj_repr
fields = ((attr, getattr(self, attr)) for attr in self.__slots__)
attrs = ", ".join(f"{name}={field!r}" for name, field in fields)
return f"{self.__class__.__name__}({attrs})"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):