Make TelegramBareClient able to invoke requests

This commit is contained in:
Lonami Exo 2018-06-09 21:03:48 +02:00
parent 7e68274f26
commit 3e151a1b7a
5 changed files with 347 additions and 215 deletions

View File

@ -2,7 +2,6 @@ import asyncio
import logging
from . import MTProtoPlainSender, authenticator
from .connection import ConnectionTcpFull
from .. import utils
from ..errors import (
BadMessageError, TypeNotFoundError, BrokenAuthKeyError, SecurityError,
@ -39,12 +38,16 @@ class MTProtoSender:
A new authorization key will be generated on connection if no other
key exists yet.
"""
def __init__(self, state, retries=5):
def __init__(self, state, connection, *, retries=5,
first_query=None, update_callback=None):
self.state = state
self._connection = ConnectionTcpFull()
self._connection = connection
self._ip = None
self._port = None
self._retries = retries
self._first_query = first_query
self._is_first_query = bool(first_query)
self._update_callback = update_callback
# Whether the user has explicitly connected or disconnected.
#
@ -110,6 +113,9 @@ class MTProtoSender:
self._user_connected = True
await self._connect()
def is_connected(self):
    """
    Returns the flag that records whether the user has explicitly
    connected (it is set to ``True`` when connecting).
    """
    return self._user_connected
async def disconnect(self):
"""
Cleanly disconnects the instance from the network, cancels
@ -227,6 +233,12 @@ class MTProtoSender:
__log__.debug('Starting send loop')
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
if self._is_first_query:
__log__.debug('Running first query')
self._is_first_query = False
async with self._send_lock:
self.send(self._first_query)
__log__.debug('Starting receive loop')
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
__log__.info('Connection to {} complete!'.format(self._ip))
@ -445,9 +457,8 @@ class MTProtoSender:
async def _handle_update(self, message):
    """
    Logs the class name of the received update and, when an update
    callback was configured, hands ``message.obj`` over to it.
    """
    update = message.obj
    __log__.debug('Handling update {}'
                  .format(update.__class__.__name__))

    callback = self._update_callback
    if callback:
        # The callback is invoked synchronously and its result is ignored.
        callback(update)
async def _handle_pong(self, message):
"""

View File

@ -1,24 +1,17 @@
import asyncio
import itertools
import logging
import platform
from datetime import timedelta, datetime
from . import version, utils
from . import version, errors, utils
from .crypto import rsa
from .extensions import markdown
from .network import MTProtoSender, ConnectionTcpFull
from .network.mtprotostate import MTProtoState
from .sessions import Session, SQLiteSession
from .tl import TLObject
from .tl import TLObject, types, functions
from .tl.all_tlobjects import LAYER
from .tl.functions import (
InitConnectionRequest, InvokeWithLayerRequest
)
from .tl.functions.auth import (
ImportAuthorizationRequest, ExportAuthorizationRequest
)
from .tl.functions.help import (
GetCdnConfigRequest, GetConfigRequest
)
from .tl.types.auth import ExportedAuthorization
from .update_state import UpdateState
DEFAULT_DC_ID = 4
@ -29,7 +22,6 @@ DEFAULT_PORT = 443
__log__ = logging.getLogger(__name__)
# TODO Do we need this class?
class TelegramBareClient:
"""
A bare Telegram client that somewhat eases the usage of the
@ -71,23 +63,10 @@ class TelegramBareClient:
A tuple consisting of ``(socks.SOCKS5, 'host', port)``.
See https://github.com/Anorov/PySocks#usage-1 for more.
update_workers (`int`, optional):
If specified, represents how many extra threads should
be spawned to handle incoming updates, and updates will
be kept in memory until they are processed. Note that
you must set this to at least ``0`` if you want to be
able to process updates through :meth:`updates.poll()`.
timeout (`int` | `float` | `timedelta`, optional):
The timeout to be used when receiving responses from
the network. Defaults to 5 seconds.
spawn_read_thread (`bool`, optional):
Whether to use an extra background thread or not. Defaults
to ``True`` so receiving items from the network happens
instantly, as soon as they arrive. Can still be disabled
if you want to run the library without any additional thread.
report_errors (`bool`, optional):
Whether to report RPC errors or not. Defaults to ``True``,
see :ref:`api-status` for more information.
@ -170,7 +149,22 @@ class TelegramBareClient:
if isinstance(connection, type):
connection = connection(proxy=proxy, timeout=timeout)
self._sender = MTProtoSender(self.session, connection)
# Used on connection - the user may modify these and reconnect
system = platform.uname()
state = MTProtoState(self.session.auth_key)
first = functions.InvokeWithLayerRequest(
LAYER, functions.InitConnectionRequest(
api_id=self.api_id,
device_model=device_model or system.system or 'Unknown',
system_version=system_version or system.release or '1.0',
app_version=app_version or self.__version__,
lang_code=lang_code,
system_lang_code=system_lang_code,
lang_pack='', # "langPacks are for official apps only"
query=functions.help.GetConfigRequest()
)
)
self._sender = MTProtoSender(state, connection, first_query=first)
# Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively
@ -179,16 +173,7 @@ class TelegramBareClient:
# This member will process updates if enabled.
# One may change self.updates.enabled at any later point.
# TODO Stop using that 1
self.updates = UpdateState(1)
# Used on connection - the user may modify these and reconnect
system = platform.uname()
self.device_model = device_model or system.system or 'Unknown'
self.system_version = system_version or system.release or '1.0'
self.app_version = app_version or self.__version__
self.lang_code = lang_code
self.system_lang_code = system_lang_code
self.updates = UpdateState()
# Save whether the user is authorized here (a.k.a. logged in)
self._authorized = None # None = We don't know yet
@ -229,14 +214,10 @@ class TelegramBareClient:
# region Connecting
async def connect(self, _sync_updates=True):
async def connect(self):
"""
Connects to Telegram.
"""
# TODO Maybe we should rethink what the session does if the sender
# needs a session but it might connect to arbitrary IPs?
#
# TODO sync updates/connected and authorized if no UnauthorizedError?
await self._sender.connect(
self.session.server_address, self.session.port)
@ -246,22 +227,6 @@ class TelegramBareClient:
"""
return self._sender.is_connected()
def _wrap_init_connection(self, query):
"""
Wraps `query` around
``InvokeWithLayerRequest(InitConnectionRequest(...))``.
"""
return InvokeWithLayerRequest(LAYER, InitConnectionRequest(
api_id=self.api_id,
device_model=self.device_model,
system_version=self.system_version,
app_version=self.app_version,
lang_code=self.lang_code,
system_lang_code=self.system_lang_code,
lang_pack='', # "langPacks are for official apps only"
query=query
))
async def disconnect(self):
"""
Disconnects from Telegram.
@ -273,7 +238,7 @@ class TelegramBareClient:
def _switch_dc(self, new_dc):
"""
Switches the current connection to the new data center.
Permanently switches the current connection to the new data center.
"""
# TODO Implement
raise NotImplementedError
@ -288,44 +253,39 @@ class TelegramBareClient:
self.disconnect()
return self.connect()
def set_proxy(self, proxy):
"""Change the proxy used by the connections.
"""
if self.is_connected():
raise RuntimeError("You can't change the proxy while connected.")
# TODO Should we tell the user to create a new client?
# Can this be done more cleanly? Similar to `switch_dc`
self._sender._connection.conn.proxy = proxy
# endregion
# region Working with different connections/Data Centers
async def _get_dc(self, dc_id, cdn=False):
    """
    Gets the Data Center (DC) associated to 'dc_id'.

    The configuration is fetched once and cached on the class
    (``TelegramBareClient._config``). The returned DC option must
    match ``dc_id``, the client's IPv6 preference and the ``cdn`` flag.

    Args:
        dc_id: The identifier of the data center to look up.
        cdn: Whether a CDN data center is wanted. If so, the CDN
            public keys are refreshed before searching.

    Returns:
        The first matching entry of the cached ``dc_options``.

    Raises:
        If no option matches and ``cdn`` is false, the ``StopIteration``
        from ``next()`` is re-raised. Note that Python converts a
        ``StopIteration`` escaping a coroutine into ``RuntimeError``.
    """
    if not TelegramBareClient._config:
        TelegramBareClient._config =\
            await self(functions.help.GetConfigRequest())

    try:
        if cdn:
            # Ensure we have the latest keys for the CDNs
            result = await self(functions.help.GetCdnConfigRequest())
            for pk in result.public_keys:
                rsa.add_key(pk.public_key)

        return next(
            dc for dc in TelegramBareClient._config.dc_options
            if dc.id == dc_id
            and bool(dc.ipv6) == self._use_ipv6 and bool(dc.cdn) == cdn
        )
    except StopIteration:
        if not cdn:
            raise

        # New configuration, perhaps a new CDN was added?
        TelegramBareClient._config =\
            await self(functions.help.GetConfigRequest())

        # The recursive call must be awaited: this method is a coroutine,
        # so returning it bare would hand the caller a coroutine object
        # instead of the data center option.
        # NOTE(review): this retry is unbounded if the CDN never shows up
        # in the refreshed configuration - consider limiting it.
        return await self._get_dc(dc_id, cdn=cdn)
def _get_exported_client(self, dc_id):
async def _get_exported_client(self, dc_id):
"""Creates and connects a new TelegramBareClient for the desired DC.
If it's the first time calling the method with a given dc_id,
@ -333,6 +293,8 @@ class TelegramBareClient:
Exporting/Importing the authorization will also be done so that
the auth is bound with the key.
"""
# TODO Implement
raise NotImplementedError
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization! ^^
session = self._exported_sessions.get(dc_id)
@ -346,7 +308,8 @@ class TelegramBareClient:
# Export the current authorization to the new DC.
__log__.info('Exporting authorization for data center %s', dc)
export_auth = self(ExportAuthorizationRequest(dc_id))
export_auth =\
await self(functions.auth.ExportAuthorizationRequest(dc_id))
# Create a temporary session for this IP address, which needs
# to be different because each auth_key is unique per DC.
@ -374,11 +337,13 @@ class TelegramBareClient:
client._authorized = True # We exported the auth, so we got auth
return client
def _get_cdn_client(self, cdn_redirect):
async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._get_exported_client, but for CDNs"""
# TODO Implement
raise NotImplementedError
session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session
@ -403,7 +368,7 @@ class TelegramBareClient:
# region Invoking Telegram requests
async def __call__(self, request, ordered=False):
async def __call__(self, request, retries=5, ordered=False):
"""
Invokes (sends) one or more MTProtoRequests and returns (receives)
their result.
@ -412,13 +377,6 @@ class TelegramBareClient:
request (`TLObject` | `list`):
The request or requests to be invoked.
retries (`bool`, optional):
How many times the request should be retried automatically
in case it fails with a non-RPC error.
The invoke will be retried up to 'retries' times before raising
``RuntimeError``.
ordered (`bool`, optional):
Whether the requests (if more than one was given) should be
executed sequentially on the server. They run in arbitrary
@ -433,25 +391,257 @@ class TelegramBareClient:
x.content_related for x in requests):
raise TypeError('You can only invoke requests, not types!')
# TODO Resolve requests, should be done by TelegramClient
# for r in requests:
# await r.resolve(self, utils)
for r in requests:
await r.resolve(self, utils)
# TODO InvokeWithLayer if no authkey, maybe done in MTProtoSender?
# TODO Handle PhoneMigrateError, NetworkMigrateError, UserMigrateError
# ^ by switching DC
# TODO Retry on ServerError, RpcCallFailError
# TODO Auto-sleep on some FloodWaitError, FloodTestPhoneWaitError
future = await self._sender.send(request, ordered=ordered)
if isinstance(future, list):
results = []
for f in future:
results.append(await future)
return results
else:
return await future
for _ in range(retries):
try:
future = self._sender.send(request, ordered=ordered)
if isinstance(future, list):
results = []
for f in future:
results.append(await f)
return results
else:
return await future
except (errors.ServerError, errors.RpcCallFailError):
pass
except (errors.FloodWaitError, errors.FloodTestPhoneWaitError) as e:
if e.seconds <= self.session.flood_sleep_threshold:
await asyncio.sleep(e.seconds)
else:
raise
except (errors.PhoneMigrateError, errors.NetworkMigrateError,
errors.UserMigrateError) as e:
await self._switch_dc(e.new_dc)
raise ValueError('Number of retries reached 0')
# Let people use client.invoke(SomeRequest()) instead of client(...)
invoke = __call__
# endregion
# region Minimal helpers
async def get_me(self, input_peer=False):
    """
    Fetches the currently authenticated user ("me").

    Args:
        input_peer (`bool`, optional):
            If true, the :tl:`InputPeerUser` form is returned instead
            of the full :tl:`User`. That form is cached on the client
            after the first successful request, so later calls with
            ``input_peer=True`` need no network round trip.

    Returns:
        Your own :tl:`User` (or its input peer version), or ``None``
        if the request fails with an unauthorized error.
    """
    if input_peer and self._self_input_peer:
        # Already known, no need to ask the server again.
        return self._self_input_peer

    try:
        response = await self(
            functions.users.GetUsersRequest([types.InputUserSelf()]))
        me = response[0]

        if not self._self_input_peer:
            self._self_input_peer = utils.get_input_peer(
                me, allow_self=False
            )

        if input_peer:
            return self._self_input_peer
        return me
    except errors.UnauthorizedError:
        return None
async def get_entity(self, entity):
    """
    Turns the given entity into a valid Telegram user or chat.

    Args:
        entity (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`):
            The entity (or iterable of entities) to be transformed.

            Strings are resolved through `_get_entity_from_string`
            (phone numbers, usernames, invite links...).

            Anything else is first converted with `get_input_entity`
            and then fetched from the server, grouping users, chats
            and channels so that as few requests as possible are made.

    Returns:
        :tl:`User`, :tl:`Chat` or :tl:`Channel` corresponding to the
        input entity. A list will be returned if more than one was given.
    """
    single = not utils.is_list_like(entity)
    targets = (entity,) if single else entity

    # Strings are kept untouched for now; everything else becomes
    # its input peer version so it can be grouped by kind below.
    inputs = []
    for item in targets:
        if isinstance(item, str):
            inputs.append(item)
        else:
            inputs.append(await self.get_input_entity(item))

    user_peers = [
        x for x in inputs
        if isinstance(x, (types.InputPeerUser, types.InputPeerSelf))
    ]
    chat_ids = [
        x.chat_id for x in inputs if isinstance(x, types.InputPeerChat)
    ]
    channel_peers = [
        x for x in inputs if isinstance(x, types.InputPeerChannel)
    ]

    # GetUsersRequest has a limit of 200 per call
    users = []
    for start in range(0, len(user_peers), 200):
        batch = user_peers[start:start + 200]
        users.extend(await self(functions.users.GetUsersRequest(batch)))

    chats = []
    if chat_ids:  # TODO Handle chats slice?
        chats = (await self(
            functions.messages.GetChatsRequest(chat_ids))).chats

    channels = []
    if channel_peers:
        channels = (await self(
            functions.channels.GetChannelsRequest(channel_peers))).chats

    # Merge users, chats and channels into a single dictionary
    by_id = {}
    for found in itertools.chain(users, chats, channels):
        by_id[utils.get_peer_id(found)] = found

    # We could check saved usernames and put them into the users,
    # chats and channels list from before. While this would reduce
    # the amount of ResolveUsername calls, it would fail to catch
    # username changes.
    resolved = []
    for item in inputs:
        if isinstance(item, str):
            resolved.append(await self._get_entity_from_string(item))
        elif isinstance(item, types.InputPeerSelf):
            # "Self" carries no ID, so look for the user flagged as self.
            resolved.append(next(
                u for u in by_id.values()
                if isinstance(u, types.User) and u.is_self
            ))
        else:
            resolved.append(by_id[utils.get_peer_id(item)])

    return resolved[0] if single else resolved
async def get_input_entity(self, peer):
    """
    Turns the given peer into its input entity version. Most requests
    use this kind of InputUser, InputChat and so on, so this is the
    most suitable call to make for those cases.

    Args:
        peer (`str` | `int` | :tl:`Peer` | :tl:`InputPeer`):
            What to convert. ``'me'`` and ``'self'`` map directly to
            :tl:`InputPeerSelf`. Otherwise the session cache is tried
            first; strings that are not cached are resolved through
            `_get_entity_from_string`, and objects that are neither
            integers nor bare ``Peer`` instances are cast with
            ``utils.get_input_peer``.

    Returns:
        The input peer for the given value. If you need to get the
        ID of yourself, you should use `get_me` with
        ``input_peer=True`` instead.

    Raises:
        ValueError: if an integer ID or a bare ``Peer`` is given and
            the session does not know about it.
    """
    if peer in ('me', 'self'):
        return types.InputPeerSelf()

    try:
        # First try to get the entity from cache, otherwise figure it out
        return self.session.get_input_entity(peer)
    except ValueError:
        pass

    if isinstance(peer, str):
        found = await self._get_entity_from_string(peer)
        return utils.get_input_peer(found)

    # Bare Peer objects (0x2d45687 == crc32(b'Peer')) lack the access
    # hash, and a plain integer is just an ID that was not found above,
    # so neither of them can be cast into an input peer.
    is_bare_peer = (isinstance(peer, TLObject)
                    and peer.SUBCLASS_OF_ID == 0x2d45687)
    if not isinstance(peer, int) and not is_bare_peer:
        # Try casting the object into an input peer. Might TypeError.
        return utils.get_input_peer(peer)

    raise ValueError(
        'Could not find the input entity for "{}". Please read https://'
        'telethon.readthedocs.io/en/latest/extra/basic/entities.html to'
        ' find out more details.'
        .format(peer)
    )
# endregion
# region Private methods
async def _get_entity_from_string(self, string):
    """
    Gets a full entity from the given string, which may be a phone or
    an username. The string may also be a user link, or a channel/chat
    invite link.

    Lookup order:
        * If the string parses as a phone number, the contact list is
          searched for a user with that phone.
        * If it is an invite link, the invite is checked and the chat
          is returned when the account is already part of it.
        * If it is a username, it is resolved on the server and the
          user or chat whose username matches is returned.
        * As a last resort the session is asked for an entity with
          that exact string.

    Returns:
        The found entity.

    Raises:
        ValueError: if the invite belongs to a chat the account has
            not joined, if the username is not occupied, or if nothing
            corresponding to the string could be found.
    """
    # NOTE(review): the nesting below was reconstructed from a view that
    # had lost its indentation; confirm the session fallback belongs to
    # the non-phone branch.
    phone = utils.parse_phone(string)
    if phone:
        for user in (await self(
                functions.contacts.GetContactsRequest(0))).users:
            if user.phone == phone:
                return user
    else:
        username, is_join_chat = utils.parse_username(string)
        if is_join_chat:
            invite = await self(
                functions.messages.CheckChatInviteRequest(username))

            if isinstance(invite, types.ChatInvite):
                raise ValueError(
                    'Cannot get entity from a channel (or group) '
                    'that you are not part of. Join the group and retry'
                )
            elif isinstance(invite, types.ChatInviteAlready):
                return invite.chat
        elif username:
            if username in ('me', 'self'):
                return await self.get_me()

            try:
                result = await self(
                    functions.contacts.ResolveUsernameRequest(username))
            except errors.UsernameNotOccupiedError as e:
                raise ValueError('No user has "{}" as username'
                                 .format(username)) from e

            for entity in itertools.chain(result.users, result.chats):
                # The parentheses matter: without them ``.lower()`` binds
                # to the empty string only and the condition becomes
                # ``username_attr or ('' == username)``, which is true
                # for any entity that has a username at all.
                entity_username = getattr(entity, 'username', None) or ''
                if entity_username.lower() == username:
                    return entity

        try:
            # Nobody with this username, maybe it's an exact name/title
            return await self.get_entity(
                self.session.get_input_entity(string))
        except ValueError:
            pass

    raise ValueError(
        'Cannot find any entity corresponding to "{}"'.format(string)
    )
# endregion

View File

@ -1,6 +1,5 @@
import struct
from datetime import datetime, date
from threading import Event
class TLObject:
@ -155,7 +154,7 @@ class TLObject:
return TLObject.pretty_format(self, indent=0)
# These should be overridden
async def resolve(self, client, utils):
    """
    Hook meant to be overridden by subclasses. The base implementation
    does nothing.

    It is a coroutine because callers await it before sending a
    request, and overriding implementations may need to await the
    client themselves.
    """
    pass
def to_dict(self):

View File

@ -17,17 +17,7 @@ class UpdateState:
"""
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke the self.handler.
"""
self._workers = workers
self._worker_threads = []
def __init__(self):
self.handler = None
self._updates_lock = RLock()
self._updates = Queue()
@ -50,66 +40,6 @@ class UpdateState:
except Empty:
return None
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
if n is None:
self.stop_workers()
else:
self._workers = n
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""
Waits for all the worker threads to stop.
"""
# Put dummy ``None`` objects so that they don't need to timeout.
n = self._workers
self._workers = None
if n:
with self._updates_lock:
for _ in range(n):
self._updates.put(None)
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
self._workers = n
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while self._workers is not None:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
if update and self.handler:
self.handler(update)
except StopIteration:
break
except:
# We don't want to crash a worker thread due to any reason
__log__.exception('Unhandled exception on worker %d', wid)
def get_update_state(self, entity_id):
    """
    Gets the stored updates state.

    The ``entity_id`` argument is currently not used: the same stored
    state is returned no matter which entity is asked for.
    """
    state = self._state
    return state
@ -118,35 +48,32 @@ class UpdateState:
"""Processes an update object. This method is normally called by
the library itself.
"""
if self._workers is None:
return # No processing needs to be done if nobody's working
if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state')
self._state = update
return # Nothing else to be done
with self._updates_lock:
if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state')
self._state = update
return # Nothing else to be done
if hasattr(update, 'pts'):
self._state.pts = update.pts
if hasattr(update, 'pts'):
self._state.pts = update.pts
# After running the script for over an hour and receiving over
# 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports.
# This should only be used as read-only.
if isinstance(update, tl.UpdateShort):
update.update._entities = {}
self._updates.put(update.update)
# After running the script for over an hour and receiving over
# 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports.
# This should only be used as read-only.
if isinstance(update, tl.UpdateShort):
update.update._entities = {}
self._updates.put(update.update)
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
u._entities = entities
self._updates.put(u)
# TODO Handle "tl.UpdatesTooLong"
else:
update._entities = {}
self._updates.put(update)
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
u._entities = entities
self._updates.put(u)
# TODO Handle "tl.UpdatesTooLong"
else:
update._entities = {}
self._updates.put(update)

View File

@ -14,10 +14,15 @@ AUTO_GEN_NOTICE = \
AUTO_CASTS = {
'InputPeer': 'utils.get_input_peer(client.get_input_entity({}))',
'InputChannel': 'utils.get_input_channel(client.get_input_entity({}))',
'InputUser': 'utils.get_input_user(client.get_input_entity({}))',
'InputDialogPeer': 'utils.get_input_dialog(client.get_input_entity({}))',
'InputPeer':
'utils.get_input_peer(await client.get_input_entity({}))',
'InputChannel':
'utils.get_input_channel(await client.get_input_entity({}))',
'InputUser':
'utils.get_input_user(await client.get_input_entity({}))',
'InputDialogPeer':
'utils.get_input_dialog(await client.get_input_entity({}))',
'InputMedia': 'utils.get_input_media({})',
'InputPhoto': 'utils.get_input_photo({})',
'InputMessage': 'utils.get_input_message({})'
@ -234,7 +239,7 @@ def _write_class_init(tlobject, type_constructors, builder):
def _write_resolve(tlobject, builder):
if any(arg.type in AUTO_CASTS for arg in tlobject.real_args):
builder.writeln('def resolve(self, client, utils):')
builder.writeln('async def resolve(self, client, utils):')
for arg in tlobject.real_args:
ac = AUTO_CASTS.get(arg.type, None)
if not ac: