Telethon/telethon/_client/telegrambaseclient.py
2022-01-24 11:05:58 +01:00

471 lines
17 KiB
Python

import abc
import re
import asyncio
import collections
import logging
import platform
import time
import typing
import ipaddress
import dataclasses
from .. import version, __name__ as __base_name__, _tl
from .._crypto import rsa
from .._misc import markdown, enums, helpers
from .._network import MTProtoSender, Connection, transports
from .._sessions import Session, SQLiteSession, MemorySession
from .._sessions.types import DataCenter, SessionState, EntityType, ChannelState
from .._updates import EntityCache, MessageBox
DEFAULT_DC_ID = 2
DEFAULT_IPV4_IP = '149.154.167.51'
DEFAULT_IPV6_IP = '2001:67c:4e8:f002::a'
DEFAULT_PORT = 443
if typing.TYPE_CHECKING:
from .telegramclient import TelegramClient
_base_log = logging.getLogger(__base_name__)
# In seconds, how long to wait before disconnecting a exported sender.
_DISCONNECT_EXPORTED_AFTER = 60
class _ExportState:
def __init__(self):
# ``n`` is the amount of borrows a given sender has;
# once ``n`` reaches ``0``, disconnect the sender after a while.
self._n = 0
self._zero_ts = 0
self._connected = False
def add_borrow(self):
self._n += 1
self._connected = True
def add_return(self):
self._n -= 1
assert self._n >= 0, 'returned sender more than it was borrowed'
if self._n == 0:
self._zero_ts = time.time()
def should_disconnect(self):
return (self._n == 0
and self._connected
and (time.time() - self._zero_ts) > _DISCONNECT_EXPORTED_AFTER)
def need_connect(self):
return not self._connected
def mark_disconnected(self):
assert self.should_disconnect(), 'marked as disconnected when it was borrowed'
self._connected = False
# TODO How hard would it be to support both `trio` and `asyncio`?
def init(
self: 'TelegramClient',
session: 'typing.Union[str, Session]',
api_id: int,
api_hash: str,
*,
# Logging.
base_logger: typing.Union[str, logging.Logger] = None,
# Connection parameters.
use_ipv6: bool = False,
proxy: typing.Union[tuple, dict] = None,
local_addr: typing.Union[str, tuple] = None,
device_model: str = None,
system_version: str = None,
app_version: str = None,
lang_code: str = 'en',
system_lang_code: str = 'en',
# Nice-to-have.
auto_reconnect: bool = True,
connect_timeout: int = 10,
connect_retries: int = 4,
connect_retry_delay: int = 1,
request_retries: int = 4,
flood_sleep_threshold: int = 60,
# Update handling.
catch_up: bool = False,
receive_updates: bool = True,
max_queued_updates: int = 100,
):
# Logging.
if isinstance(base_logger, str):
base_logger = logging.getLogger(base_logger)
elif not isinstance(base_logger, logging.Logger):
base_logger = _base_log
class _Loggers(dict):
def __missing__(self, key):
if key.startswith("telethon."):
key = key.split('.', maxsplit=1)[1]
return base_logger.getChild(key)
self._log = _Loggers()
# Sessions.
if isinstance(session, str) or session is None:
try:
session = SQLiteSession(session)
except ImportError:
import warnings
warnings.warn(
'The sqlite3 module is not available under this '
'Python installation and no _ session '
'instance was given; using MemorySession.\n'
'You will need to re-login every time unless '
'you use another session storage'
)
session = MemorySession()
elif not isinstance(session, Session):
raise TypeError(
'The given session must be a str or a Session instance.'
)
self._session = session
# In-memory copy of the session's state to avoid a roundtrip as it contains commonly-accessed values.
self._session_state = None
# Nice-to-have.
self._request_retries = request_retries
self._connect_retries = connect_retries
self._connect_retry_delay = connect_retry_delay or 0
self._connect_timeout = connect_timeout
self.flood_sleep_threshold = flood_sleep_threshold
self._flood_waited_requests = {} # prevent calls that would floodwait entirely
self._parse_mode = markdown
# Update handling.
self._catch_up = catch_up
self._no_updates = not receive_updates
self._updates_queue = asyncio.Queue(maxsize=max_queued_updates)
self._updates_handle = None
self._message_box = MessageBox()
self._entity_cache = EntityCache() # required for proper update handling (to know when to getDifference)
# Connection parameters.
if not api_id or not api_hash:
raise ValueError(
"Your API ID or Hash cannot be empty or None. "
"Refer to telethon.rtfd.io for more information.")
if local_addr is not None:
if use_ipv6 is False and ':' in local_addr:
raise TypeError('A local IPv6 address must only be used with `use_ipv6=True`.')
elif use_ipv6 is True and ':' not in local_addr:
raise TypeError('`use_ipv6=True` must only be used with a local IPv6 address.')
self._transport = transports.Full()
self._use_ipv6 = use_ipv6
self._local_addr = local_addr
self._proxy = proxy
self._auto_reconnect = auto_reconnect
self._api_id = int(api_id)
self._api_hash = api_hash
# Used on connection. Capture the variables in a lambda since
# exporting clients need to create this InvokeWithLayer.
system = platform.uname()
if system.machine in ('x86_64', 'AMD64'):
default_device_model = 'PC 64bit'
elif system.machine in ('i386','i686','x86'):
default_device_model = 'PC 32bit'
else:
default_device_model = system.machine
default_system_version = re.sub(r'-.+','',system.release)
self._init_request = _tl.fn.InitConnection(
api_id=self._api_id,
device_model=device_model or default_device_model or 'Unknown',
system_version=system_version or default_system_version or '1.0',
app_version=app_version or self.__version__,
lang_code=lang_code,
system_lang_code=system_lang_code,
lang_pack='', # "langPacks are for official apps only"
query=None,
proxy=None
)
self._sender = MTProtoSender(
loggers=self._log,
retries=self._connect_retries,
delay=self._connect_retry_delay,
auto_reconnect=self._auto_reconnect,
connect_timeout=self._connect_timeout,
updates_queue=self._updates_queue,
)
# Cache ``{dc_id: (_ExportState, MTProtoSender)}`` for all borrowed senders.
self._borrowed_senders = {}
self._borrow_sender_lock = asyncio.Lock()
def get_flood_sleep_threshold(self):
return self._flood_sleep_threshold
def set_flood_sleep_threshold(self, value):
# None -> 0, negative values don't really matter
self._flood_sleep_threshold = min(value or 0, 24 * 60 * 60)
async def connect(self: 'TelegramClient') -> None:
all_dcs = {dc.id: dc for dc in await self._session.get_all_dc()}
self._session_state = await self._session.get_state()
if self._session_state is None:
try_fetch_user = False
self._session_state = SessionState(
user_id=0,
dc_id=DEFAULT_DC_ID,
bot=False,
pts=0,
qts=0,
date=0,
seq=0,
takeout_id=None,
)
else:
try_fetch_user = self._session_state.user_id == 0
if self._catch_up:
channel_states = await self._session.get_all_channel_states()
self._message_box.load(self._session_state, channel_states)
for state in channel_states:
entity = await self._session.get_entity(EntityType.CHANNEL, state.channel_id)
if entity:
self._entity_cache.put(entity)
dc = all_dcs.get(self._session_state.dc_id)
if dc is None:
dc = DataCenter(
id=DEFAULT_DC_ID,
ipv4=None if self._use_ipv6 else int(ipaddress.ip_address(DEFAULT_IPV4_IP)),
ipv6=int(ipaddress.ip_address(DEFAULT_IPV6_IP)) if self._use_ipv6 else None,
port=DEFAULT_PORT,
auth=b'',
)
all_dcs[dc.id] = dc
# Use known key, if any
self._sender.auth_key.key = dc.auth
if not await self._sender.connect(Connection(
ip=str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
port=dc.port,
transport=self._transport.recreate_fresh(),
loggers=self._log,
local_addr=self._local_addr,
)):
# We don't want to init or modify anything if we were already connected
return
if self._sender.auth_key.key != dc.auth:
all_dcs[dc.id] = dc = dataclasses.replace(dc, auth=self._sender.auth_key.key)
# Need to send invokeWithLayer for things to work out.
# Make the most out of this opportunity by also refreshing our state.
# During the v1 to v2 migration, this also correctly sets the IPv* columns.
self._init_request.query = _tl.fn.help.GetConfig()
config = await self._sender.send(_tl.fn.InvokeWithLayer(
_tl.LAYER, self._init_request
))
for dc in config.dc_options:
if dc.media_only or dc.tcpo_only or dc.cdn:
continue
ip = int(ipaddress.ip_address(dc.ip_address))
if dc.id in all_dcs:
if dc.ipv6:
all_dcs[dc.id] = dataclasses.replace(all_dcs[dc.id], port=dc.port, ipv6=ip)
else:
all_dcs[dc.id] = dataclasses.replace(all_dcs[dc.id], port=dc.port, ipv4=ip)
elif dc.ipv6:
all_dcs[dc.id] = DataCenter(dc.id, None, ip, dc.port, b'')
else:
all_dcs[dc.id] = DataCenter(dc.id, ip, None, dc.port, b'')
for dc in all_dcs.values():
await self._session.insert_dc(dc)
if try_fetch_user:
# If there was a previous session state, but the current user ID is 0, it means we've
# migrated and not yet populated the current user (or the client connected but never
# logged in). Attempt to fetch the user now. If it works, also get the update state.
me = await self.get_me()
if me:
await self._update_session_state(me, save=False)
await self._session.save()
self._updates_handle = asyncio.create_task(self._update_loop())
def is_connected(self: 'TelegramClient') -> bool:
sender = getattr(self, '_sender', None)
return sender and sender.is_connected()
async def disconnect(self: 'TelegramClient'):
return await _disconnect_coro(self)
def set_proxy(self: 'TelegramClient', proxy: typing.Union[tuple, dict]):
init_proxy = None
self._init_request.proxy = init_proxy
self._proxy = proxy
# While `await client.connect()` passes new proxy on each new call,
# auto-reconnect attempts use already set up `_connection` inside
# the `_sender`, so the only way to change proxy between those
# is to directly inject parameters.
connection = getattr(self._sender, "_connection", None)
if connection:
if isinstance(connection, conns.TcpMTProxy):
connection._ip = proxy[0]
connection._port = proxy[1]
else:
connection._proxy = proxy
async def _disconnect_coro(self: 'TelegramClient'):
await _disconnect(self)
# Also clean-up all exported senders because we're done with them
async with self._borrow_sender_lock:
for state, sender in self._borrowed_senders.values():
# Note that we're not checking for `state.should_disconnect()`.
# If the user wants to disconnect the client, ALL connections
# to Telegram (including exported senders) should be closed.
#
# Disconnect should never raise, so there's no try/except.
await sender.disconnect()
# Can't use `mark_disconnected` because it may be borrowed.
state._connected = False
# If any was borrowed
self._borrowed_senders.clear()
async def _disconnect(self: 'TelegramClient'):
"""
Disconnect only, without closing the session. Used in reconnections
to different data centers, where we don't want to close the session
file; user disconnects however should close it since it means that
their job with the client is complete and we should clean it up all.
"""
await self._sender.disconnect()
await helpers._cancel(self._log[__name__], updates_handle=self._updates_handle)
try:
await self._updates_handle
except asyncio.CancelledError:
pass
await self._session.insert_entities(self._entity_cache.get_all_entities())
session_state, channel_states = self._message_box.session_state()
for channel_id, pts in channel_states.items():
await self._session.insert_channel_state(ChannelState(channel_id=channel_id, pts=pts))
await self._replace_session_state(**session_state)
async def _switch_dc(self: 'TelegramClient', new_dc):
"""
Permanently switches the current connection to the new data center.
"""
self._log[__name__].info('Reconnecting to new data center %s', new_dc)
await self._replace_session_state(dc_id=new_dc)
await _disconnect(self)
return await self.connect()
async def _create_exported_sender(self: 'TelegramClient', dc_id):
"""
Creates a new exported `MTProtoSender` for the given `dc_id` and
returns it. This method should be used by `_borrow_exported_sender`.
"""
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization
dc = next(dc for dc in await self._session.get_all_dc() if dc.id == dc_id)
# Can't reuse self._sender._connection as it has its own seqno.
#
# If one were to do that, Telegram would reset the connection
# with no further clues.
sender = MTProtoSender(loggers=self._log)
await self._sender.connect(Connection(
ip=str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
port=dc.port,
transport=self._transport.recreate_fresh(),
loggers=self._log,
local_addr=self._local_addr,
))
self._log[__name__].info('Exporting auth for new borrowed sender in %s', dc)
auth = await self(_tl.fn.auth.ExportAuthorization(dc_id))
self._init_request.query = _tl.fn.auth.ImportAuthorization(id=auth.id, bytes=auth.bytes)
req = _tl.fn.InvokeWithLayer(_tl.LAYER, self._init_request)
await sender.send(req)
return sender
async def _borrow_exported_sender(self: 'TelegramClient', dc_id):
"""
Borrows a connected `MTProtoSender` for the given `dc_id`.
If it's not cached, creates a new one if it doesn't exist yet,
and imports a freshly exported authorization key for it to be usable.
Once its job is over it should be `_return_exported_sender`.
"""
async with self._borrow_sender_lock:
self._log[__name__].debug('Borrowing sender for dc_id %d', dc_id)
state, sender = self._borrowed_senders.get(dc_id, (None, None))
if state is None:
state = _ExportState()
sender = await _create_exported_sender(self, dc_id)
sender.dc_id = dc_id
self._borrowed_senders[dc_id] = (state, sender)
elif state.need_connect():
dc = next(dc for dc in await self._session.get_all_dc() if dc.id == dc_id)
await self._sender.connect(Connection(
ip=str(ipaddress.ip_address((self._use_ipv6 and dc.ipv6) or dc.ipv4)),
port=dc.port,
transport=self._transport.recreate_fresh(),
loggers=self._log,
local_addr=self._local_addr,
))
state.add_borrow()
return sender
async def _return_exported_sender(self: 'TelegramClient', sender):
"""
Returns a borrowed exported sender. If all borrows have
been returned, the sender is cleanly disconnected.
"""
async with self._borrow_sender_lock:
self._log[__name__].debug('Returning borrowed sender for dc_id %d', sender.dc_id)
state, _ = self._borrowed_senders[sender.dc_id]
state.add_return()
async def _clean_exported_senders(self: 'TelegramClient'):
"""
Cleans-up all unused exported senders by disconnecting them.
"""
async with self._borrow_sender_lock:
for dc_id, (state, sender) in self._borrowed_senders.items():
if state.should_disconnect():
self._log[__name__].info(
'Disconnecting borrowed sender for DC %d', dc_id)
# Disconnect should never raise
await sender.disconnect()
state.mark_disconnected()