diff --git a/client/src/telethon/_impl/client/client/auth.py b/client/src/telethon/_impl/client/client/auth.py index 45739a1d..4f2e3f52 100644 --- a/client/src/telethon/_impl/client/client/auth.py +++ b/client/src/telethon/_impl/client/client/auth.py @@ -6,10 +6,11 @@ from typing import TYPE_CHECKING, Optional, Union from ...crypto import two_factor_auth from ...mtproto import RpcError +from ...session import DataCenter from ...session import User as SessionUser from ...tl import abcs, functions, types from ..types import LoginToken, PasswordToken, User -from .net import connect_sender, datacenter_for_id +from .net import connect_sender if TYPE_CHECKING: from .client import Client @@ -50,7 +51,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User: async def handle_migrate(client: Client, dc_id: Optional[int]) -> None: assert dc_id is not None sender, client._session.dcs = await connect_sender( - client._config, datacenter_for_id(client, dc_id), client._logger + client._config, client._session.dcs, DataCenter(id=dc_id), client._logger ) async with client._sender_lock: client._sender = sender diff --git a/client/src/telethon/_impl/client/client/net.py b/client/src/telethon/_impl/client/client/net.py index 48379d85..93e81999 100644 --- a/client/src/telethon/_impl/client/client/net.py +++ b/client/src/telethon/_impl/client/client/net.py @@ -6,16 +6,16 @@ import logging import platform import re from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar from ....version import __version__ -from ...mtproto import Full, RpcError +from ...mtproto import BadStatus, Full, RpcError from ...mtsender import Sender from ...mtsender import connect as connect_without_auth from ...mtsender import connect_with_auth from ...session import DataCenter from ...session import User as SessionUser -from ...tl import LAYER, Request, functions, types +from ...tl import LAYER, Request, abcs, functions, types from ..errors import adapt_rpc from .updates import dispatcher, process_socket_updates @@ -56,72 +56,114 @@ class Config: update_queue_limit: Optional[int] = None -KNOWN_DC = [ - DataCenter(id=1, addr="149.154.175.53:443", auth=None), - DataCenter(id=2, addr="149.154.167.51:443", auth=None), - DataCenter(id=3, addr="149.154.175.100:443", auth=None), - DataCenter(id=4, addr="149.154.167.92:443", auth=None), - DataCenter(id=5, addr="91.108.56.190:443", auth=None), +KNOWN_DCS = [ + DataCenter(id=1, ipv4_addr="149.154.175.53:443", ipv6_addr=None, auth=None), + DataCenter(id=2, ipv4_addr="149.154.167.51:443", ipv6_addr=None, auth=None), + DataCenter(id=3, ipv4_addr="149.154.175.100:443", ipv6_addr=None, auth=None), + DataCenter(id=4, ipv4_addr="149.154.167.92:443", ipv6_addr=None, auth=None), + DataCenter(id=5, ipv4_addr="91.108.56.190:443", ipv6_addr=None, auth=None), ] DEFAULT_DC = 2 +def as_concrete_dc_option(opt: abcs.DcOption) -> types.DcOption: + assert isinstance(opt, types.DcOption) + return opt + + async def connect_sender( config: Config, + known_dcs: List[DataCenter], dc: DataCenter, base_logger: logging.Logger, + force_auth_gen: bool = False, ) -> Tuple[Sender, List[DataCenter]]: - transport = Full() - - if dc.auth: - sender = await connect_with_auth( - transport, dc.id, dc.addr, dc.auth, base_logger - ) - else: - sender = await connect_without_auth(transport, dc.id, dc.addr, base_logger) - - # TODO handle -404 (we had a previously-valid authkey, but server no longer knows about it) - remote_config = await sender.invoke( - functions.invoke_with_layer( - layer=LAYER, - query=functions.init_connection( - api_id=config.api_id, - device_model=config.device_model, - system_version=config.system_version, - app_version=config.app_version, - system_lang_code=config.system_lang_code, - lang_pack="", - lang_code=config.lang_code, - proxy=None, - params=None, - query=functions.help.get_config(), - ), - ) + # Only the ID of the input DC may be known. + # Find the corresponding address and authentication key if needed. + addr = dc.ipv4_addr or next( + d.ipv4_addr + for d in itertools.chain(known_dcs, KNOWN_DCS) + if d.id == dc.id and d.ipv4_addr + ) + auth = ( + None + if force_auth_gen + else dc.auth + or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None)) ) - latest_dcs = [] - append_current = True - for opt in types.Config.from_bytes(remote_config).dc_options: - assert isinstance(opt, types.DcOption) - latest_dcs.append( - DataCenter( - id=opt.id, - addr=opt.ip_address, - auth=sender.auth_key if sender.dc_id == opt.id else None, + transport = Full() + if auth: + sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger) + else: + sender = await connect_without_auth(transport, dc.id, addr, base_logger) + + try: + remote_config_data = await sender.invoke( + functions.invoke_with_layer( + layer=LAYER, + query=functions.init_connection( + api_id=config.api_id, + device_model=config.device_model, + system_version=config.system_version, + app_version=config.app_version, + system_lang_code=config.system_lang_code, + lang_pack="", + lang_code=config.lang_code, + proxy=None, + params=None, + query=functions.help.get_config(), + ), ) ) - if sender.dc_id == opt.id: - append_current = False + except BadStatus as e: + if e.status == 404 and auth: + dc = DataCenter( + id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None + ) + base_logger.warning( + "datacenter could not find stored auth; will retry generating a new one: %s", + dc, + ) + return await connect_sender( + config, known_dcs, dc, base_logger, force_auth_gen=True + ) + else: + raise - if append_current: - # Current config has no DC with current ID. - # Append it to preserve the authorization key. - latest_dcs.append( - DataCenter(id=sender.dc_id, addr=sender.addr, auth=sender.auth_key) - ) + remote_config = types.Config.from_bytes(remote_config_data) - return sender, latest_dcs + # Filter the primary data-centers to persist, static first. + dc_options = [ + opt + for opt in map(as_concrete_dc_option, remote_config.dc_options) + if not (opt.media_only or opt.tcpo_only or opt.cdn) + ] + dc_options.sort(key=lambda opt: opt.static, reverse=True) + + latest_dcs: Dict[int, DataCenter] = {} + for opt in dc_options: + dc = latest_dcs.setdefault(opt.id, DataCenter(id=opt.id)) + if opt.ipv6: + if not dc.ipv6_addr: + dc.ipv6_addr = f"{opt.ip_address}:{opt.port}" + else: + if not dc.ipv4_addr: + dc.ipv4_addr = f"{opt.ip_address}:{opt.port}" + + # Restore only missing information. + for dc in itertools.chain( + [DataCenter(id=sender.dc_id, ipv4_addr=sender.addr, auth=sender.auth_key)], + known_dcs, + ): + saved_dc = latest_dcs.setdefault(sender.dc_id, DataCenter(id=dc.id)) + saved_dc.ipv4_addr = saved_dc.ipv4_addr or dc.ipv4_addr + saved_dc.ipv6_addr = saved_dc.ipv6_addr or dc.ipv6_addr + saved_dc.auth = saved_dc.auth or dc.auth + + session_dcs = [dc for _, dc in sorted(latest_dcs.items(), key=lambda t: t[0])] + return sender, session_dcs async def connect(self: Client) -> None: @@ -131,28 +173,11 @@ async def connect(self: Client) -> None: if session := await self._storage.load(): self._session = session - if dc := self._config.datacenter: - # Datacenter override, reusing the session's auth-key unless already present. - datacenter = ( - dc - if dc.auth - else DataCenter( - id=dc.id, - addr=dc.addr, - auth=next( - (d.auth for d in self._session.dcs if d.id == dc.id and d.auth), - None, - ), - ) - ) - else: - # Reuse the session's datacenter, falling back to defaults if not found. - datacenter = datacenter_for_id( - self, self._session.user.dc if self._session.user else DEFAULT_DC - ) - + datacenter = self._config.datacenter or DataCenter( + id=self._session.user.dc if self._session.user else DEFAULT_DC + ) self._sender, self._session.dcs = await connect_sender( - self._config, datacenter, self._logger + self._config, self._session.dcs, datacenter, self._logger ) if self._message_box.is_empty() and self._session.user: @@ -177,17 +202,6 @@ async def connect(self: Client) -> None: self._dispatcher = asyncio.create_task(dispatcher(self)) -def datacenter_for_id(client: Client, dc_id: int) -> DataCenter: - try: - return next( - dc - for dc in itertools.chain(client._session.dcs, KNOWN_DC) - if dc.id == dc_id - ) - except StopIteration: - raise ValueError(f"no datacenter found for id: {dc_id}") from None - - async def disconnect(self: Client) -> None: if not self._sender: return diff --git a/client/src/telethon/_impl/session/session.py b/client/src/telethon/_impl/session/session.py index 9a008e60..cbfd2160 100644 --- a/client/src/telethon/_impl/session/session.py +++ b/client/src/telethon/_impl/session/session.py @@ -1,25 +1,48 @@ from typing import List, Optional +from ..tl.core.serializable import obj_repr + class DataCenter: """ Data-center information. :param id: See below. - :param addr: See below. + :param ipv4_addr: See below. + :param ipv6_addr: See below. :param auth: See below. """ - __slots__ = ("id", "addr", "auth") + __slots__ = ("id", "ipv4_addr", "ipv6_addr", "auth") - def __init__(self, *, id: int, addr: str, auth: Optional[bytes]) -> None: + def __init__( + self, + *, + id: int, + ipv4_addr: Optional[str] = None, + ipv6_addr: Optional[str] = None, + auth: Optional[bytes] = None, + ) -> None: self.id = id "The DC identifier." - self.addr = addr - "The server address of the DC, in ``'ip:port'`` format." + self.ipv4_addr = ipv4_addr + "The IPv4 socket server address of the DC, in ``'ip:port'`` format." + self.ipv6_addr = ipv6_addr + "The IPv6 socket server address of the DC, in ``'ip:port'`` format." self.auth = auth "Authentication key to encrypt communication with." + def __repr__(self) -> str: + # Censor auth + return obj_repr( + DataCenter( + id=self.id, + ipv4_addr=self.ipv4_addr, + ipv6_addr=self.ipv6_addr, + auth=b"..." if self.auth else None, + ) + ) + class User: """ @@ -43,6 +66,9 @@ class User: self.username = username "User's primary username." + __repr__ = obj_repr + __str__ = __repr__ + class ChannelState: """ @@ -60,6 +86,9 @@ class ChannelState: self.pts = pts "The channel's partial sequence number." + __repr__ = obj_repr + __str__ = __repr__ + class UpdateState: """ @@ -100,6 +129,9 @@ class UpdateState: self.channels = channels "Update state for channels." + __repr__ = obj_repr + __str__ = __repr__ + class Session: """ @@ -143,3 +175,6 @@ class Session: "Information about the logged-in user." self.state = state "Update state." + + __repr__ = obj_repr + __str__ = __repr__ diff --git a/client/src/telethon/_impl/session/storage/sqlite.py b/client/src/telethon/_impl/session/storage/sqlite.py index 78aca875..70f682cc 100644 --- a/client/src/telethon/_impl/session/storage/sqlite.py +++ b/client/src/telethon/_impl/session/storage/sqlite.py @@ -79,7 +79,7 @@ class SqliteSession(Storage): return Session( dcs=[ - DataCenter(id=id, addr=f"{ip}:{port}", auth=auth) + DataCenter(id=id, ipv4_addr=f"{ip}:{port}", ipv6_addr=None, auth=auth) for (id, ip, port, auth) in sessions ], user=None, @@ -105,8 +105,8 @@ class SqliteSession(Storage): return Session( dcs=[ - DataCenter(id=id, addr=addr, auth=auth) - for (id, addr, auth) in datacenter + DataCenter(id=id, ipv4_addr=ipv4_addr, ipv6_addr=ipv6_addr, auth=auth) + for (id, ipv4_addr, ipv6_addr, auth) in datacenter ], user=User(id=user[0], dc=user[1], bot=bool(user[2]), username=user[3]) if user @@ -129,8 +129,8 @@ class SqliteSession(Storage): c.execute("delete from state") c.execute("delete from channelstate") c.executemany( - "insert into datacenter values (?, ?, ?)", - [(dc.id, dc.addr, dc.auth) for dc in session.dcs], + "insert into datacenter values (?, ?, ?, ?)", + [(dc.id, dc.ipv4_addr, dc.ipv6_addr, dc.auth) for dc in session.dcs], ) if user := session.user: c.execute( @@ -188,7 +188,8 @@ class SqliteSession(Storage): ); create table datacenter( id integer primary key, - addr text not null, + ipv4_addr text, + ipv6_addr text, auth blob ); create table user( diff --git a/client/src/telethon/_impl/tl/core/serializable.py b/client/src/telethon/_impl/tl/core/serializable.py index acadc89b..2bd061df 100644 --- a/client/src/telethon/_impl/tl/core/serializable.py +++ b/client/src/telethon/_impl/tl/core/serializable.py @@ -1,10 +1,20 @@ import abc import struct -from typing import Self, Tuple +from typing import Protocol, Self, Tuple from .reader import Reader +class HasSlots(Protocol): + __slots__: Tuple[str, ...] + + +def obj_repr(obj: HasSlots) -> str: + fields = ((attr, getattr(obj, attr)) for attr in obj.__slots__) + params = ", ".join(f"{name}={field!r}" for name, field in fields) + return f"{obj.__class__.__name__}({params})" + + class Serializable(abc.ABC): __slots__: Tuple[str, ...] = () @@ -34,10 +44,7 @@ class Serializable(abc.ABC): self._write_boxed_to(buffer) return bytes(buffer) - def __repr__(self) -> str: - fields = ((attr, getattr(self, attr)) for attr in self.__slots__) - attrs = ", ".join(f"{name}={field!r}" for name, field in fields) - return f"{self.__class__.__name__}({attrs})" + __repr__ = obj_repr def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__):