mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-03-09 21:55:48 +03:00
Merge pull request #370 from andr-04/asyncio
Made update system for asyncio functional
This commit is contained in:
commit
6dc0ee9d6c
|
@ -5,13 +5,18 @@ import socket
|
|||
from datetime import timedelta
|
||||
from io import BytesIO, BufferedWriter
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
MAX_TIMEOUT = 15 # in seconds
|
||||
CONN_RESET_ERRNOS = {
|
||||
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||
errno.EINVAL, errno.ENOTCONN
|
||||
}
|
||||
|
||||
|
||||
class TcpClient:
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
|
||||
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
|
||||
self.proxy = proxy
|
||||
self._socket = None
|
||||
self._loop = loop if loop else asyncio.get_event_loop()
|
||||
|
||||
if isinstance(timeout, timedelta):
|
||||
self.timeout = timeout.seconds
|
||||
|
@ -31,7 +36,7 @@ class TcpClient:
|
|||
else: # tuple, list, etc.
|
||||
self._socket.set_proxy(*self.proxy)
|
||||
|
||||
self._socket.settimeout(self.timeout)
|
||||
self._socket.setblocking(False)
|
||||
|
||||
async def connect(self, ip, port):
|
||||
"""Connects to the specified IP and port number.
|
||||
|
@ -42,20 +47,27 @@ class TcpClient:
|
|||
else:
|
||||
mode, address = socket.AF_INET, (ip, port)
|
||||
|
||||
timeout = 1
|
||||
while True:
|
||||
try:
|
||||
while not self._socket:
|
||||
if not self._socket:
|
||||
self._recreate_socket(mode)
|
||||
|
||||
await loop.sock_connect(self._socket, address)
|
||||
await self._loop.sock_connect(self._socket, address)
|
||||
break # Successful connection, stop retrying to connect
|
||||
except ConnectionError:
|
||||
self._socket = None
|
||||
await asyncio.sleep(timeout)
|
||||
timeout = min(timeout * 2, MAX_TIMEOUT)
|
||||
except OSError as e:
|
||||
# There are some errors that we know how to handle, and
|
||||
# the loop will allow us to retry
|
||||
if e.errno == errno.EBADF:
|
||||
if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL):
|
||||
# Bad file descriptor, i.e. socket was closed, set it
|
||||
# to none to recreate it on the next iteration
|
||||
self._socket = None
|
||||
await asyncio.sleep(timeout)
|
||||
timeout = min(timeout * 2, MAX_TIMEOUT)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
@ -81,13 +93,14 @@ class TcpClient:
|
|||
raise ConnectionResetError()
|
||||
|
||||
try:
|
||||
await loop.sock_sendall(self._socket, data)
|
||||
except socket.timeout as e:
|
||||
await asyncio.wait_for(self._loop.sock_sendall(self._socket, data),
|
||||
timeout=self.timeout, loop=self._loop)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError() from e
|
||||
except BrokenPipeError:
|
||||
self._raise_connection_reset()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EBADF:
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset()
|
||||
else:
|
||||
raise
|
||||
|
@ -104,11 +117,12 @@ class TcpClient:
|
|||
bytes_left = size
|
||||
while bytes_left != 0:
|
||||
try:
|
||||
partial = await loop.sock_recv(self._socket, bytes_left)
|
||||
except socket.timeout as e:
|
||||
partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left),
|
||||
timeout=self.timeout, loop=self._loop)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError() from e
|
||||
except OSError as e:
|
||||
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK:
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset()
|
||||
else:
|
||||
raise
|
||||
|
|
|
@ -43,13 +43,13 @@ class Connection:
|
|||
"""
|
||||
|
||||
def __init__(self, mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None, timeout=timedelta(seconds=5)):
|
||||
proxy=None, timeout=timedelta(seconds=5), loop=None):
|
||||
self._mode = mode
|
||||
self._send_counter = 0
|
||||
self._aes_encrypt, self._aes_decrypt = None, None
|
||||
|
||||
# TODO Rename "TcpClient" as some sort of generic socket?
|
||||
self.conn = TcpClient(proxy=proxy, timeout=timeout)
|
||||
self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
|
||||
|
||||
# Sending messages
|
||||
if mode == ConnectionMode.TCP_FULL:
|
||||
|
@ -206,7 +206,7 @@ class Connection:
|
|||
return await self.conn.read(length)
|
||||
|
||||
async def _read_obfuscated(self, length):
|
||||
return await self._aes_decrypt.encrypt(self.conn.read(length))
|
||||
return self._aes_decrypt.encrypt(await self.conn.read(length))
|
||||
|
||||
# endregion
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import gzip
|
||||
import logging
|
||||
import struct
|
||||
import asyncio
|
||||
from asyncio import Event
|
||||
|
||||
from .. import helpers as utils
|
||||
from ..crypto import AES
|
||||
|
@ -30,17 +32,15 @@ class MtProtoSender:
|
|||
in parallel, so thread-safety (hence locking) isn't needed.
|
||||
"""
|
||||
|
||||
def __init__(self, session, connection):
|
||||
def __init__(self, session, connection, loop=None):
|
||||
"""Creates a new MtProtoSender configured to send messages through
|
||||
'connection' and using the parameters from 'session'.
|
||||
"""
|
||||
self.session = session
|
||||
self.connection = connection
|
||||
self._loop = loop if loop else asyncio.get_event_loop()
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
# Message IDs that need confirmation
|
||||
self._need_confirmation = set()
|
||||
|
||||
# Requests (as msg_id: Message) sent waiting to be received
|
||||
self._pending_receive = {}
|
||||
|
||||
|
@ -54,12 +54,11 @@ class MtProtoSender:
|
|||
def disconnect(self):
|
||||
"""Disconnects from the server"""
|
||||
self.connection.close()
|
||||
self._need_confirmation.clear()
|
||||
self._clear_all_pending()
|
||||
|
||||
def clone(self):
|
||||
"""Creates a copy of this MtProtoSender as a new connection"""
|
||||
return MtProtoSender(self.session, self.connection.clone())
|
||||
return MtProtoSender(self.session, self.connection.clone(), self._loop)
|
||||
|
||||
# region Send and receive
|
||||
|
||||
|
@ -67,21 +66,23 @@ class MtProtoSender:
|
|||
"""Sends the specified MTProtoRequest, previously sending any message
|
||||
which needed confirmation."""
|
||||
|
||||
# Prepare the event of every request
|
||||
for r in requests:
|
||||
if r.confirm_received is None:
|
||||
r.confirm_received = Event(loop=self._loop)
|
||||
else:
|
||||
r.confirm_received.clear()
|
||||
|
||||
# Finally send our packed request(s)
|
||||
messages = [TLMessage(self.session, r) for r in requests]
|
||||
self._pending_receive.update({m.msg_id: m for m in messages})
|
||||
|
||||
# Pack everything in the same container if we need to send AckRequests
|
||||
if self._need_confirmation:
|
||||
messages.append(
|
||||
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
|
||||
)
|
||||
self._need_confirmation.clear()
|
||||
|
||||
if len(messages) == 1:
|
||||
message = messages[0]
|
||||
else:
|
||||
message = TLMessage(self.session, MessageContainer(messages))
|
||||
for m in messages:
|
||||
m.container_msg_id = message.msg_id
|
||||
|
||||
await self._send_message(message)
|
||||
|
||||
|
@ -115,6 +116,7 @@ class MtProtoSender:
|
|||
message, remote_msg_id, remote_seq = self._decode_msg(body)
|
||||
with BinaryReader(message) as reader:
|
||||
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
|
||||
await self._send_acknowledge(remote_msg_id)
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -174,7 +176,6 @@ class MtProtoSender:
|
|||
"""
|
||||
|
||||
# TODO Check salt, session_id and sequence_number
|
||||
self._need_confirmation.add(msg_id)
|
||||
|
||||
code = reader.read_int(signed=False)
|
||||
reader.seek(-4)
|
||||
|
@ -210,7 +211,7 @@ class MtProtoSender:
|
|||
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
|
||||
ack = reader.tgread_object()
|
||||
assert isinstance(ack, MsgsAck)
|
||||
# Ignore every ack request *unless* when logging out, when it's
|
||||
# Ignore every ack request *unless* when logging out,
|
||||
# when it seems to only make sense. We also need to set a non-None
|
||||
# result since Telegram doesn't send the response for these.
|
||||
for msg_id in ack.msg_ids:
|
||||
|
@ -259,11 +260,29 @@ class MtProtoSender:
|
|||
if message and isinstance(message.request, t):
|
||||
return self._pending_receive.pop(msg_id).request
|
||||
|
||||
def _pop_requests_of_container(self, container_msg_id):
|
||||
msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id]
|
||||
requests = [msg.request for msg in msgs]
|
||||
for msg in msgs:
|
||||
self._pending_receive.pop(msg.msg_id, None)
|
||||
return requests
|
||||
|
||||
def _clear_all_pending(self):
|
||||
for r in self._pending_receive.values():
|
||||
r.request.confirm_received.set()
|
||||
self._pending_receive.clear()
|
||||
|
||||
async def _resend_request(self, msg_id):
|
||||
request = self._pop_request(msg_id)
|
||||
if request:
|
||||
self._logger.debug('requests is about to resend')
|
||||
await self.send(request)
|
||||
return
|
||||
requests = self._pop_requests_of_container(msg_id)
|
||||
if requests:
|
||||
self._logger.debug('container of requests is about to resend')
|
||||
await self.send(*requests)
|
||||
|
||||
async def _handle_pong(self, msg_id, sequence, reader):
|
||||
self._logger.debug('Handling pong')
|
||||
pong = reader.tgread_object()
|
||||
|
@ -305,9 +324,7 @@ class MtProtoSender:
|
|||
)[0]
|
||||
self.session.save()
|
||||
|
||||
request = self._pop_request(bad_salt.bad_msg_id)
|
||||
if request:
|
||||
await self.send(request)
|
||||
await self._resend_request(bad_salt.bad_msg_id)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -323,15 +340,18 @@ class MtProtoSender:
|
|||
self.session.update_time_offset(correct_msg_id=msg_id)
|
||||
self._logger.debug('Read Bad Message error: ' + str(error))
|
||||
self._logger.debug('Attempting to use the correct time offset.')
|
||||
await self._resend_request(bad_msg.bad_msg_id)
|
||||
return True
|
||||
elif bad_msg.error_code == 32:
|
||||
# msg_seqno too low, so just pump it up by some "large" amount
|
||||
# TODO A better fix would be to start with a new fresh session ID
|
||||
self.session._sequence += 64
|
||||
await self._resend_request(bad_msg.bad_msg_id)
|
||||
return True
|
||||
elif bad_msg.error_code == 33:
|
||||
# msg_seqno too high never seems to happen but just in case
|
||||
self.session._sequence -= 16
|
||||
await self._resend_request(bad_msg.bad_msg_id)
|
||||
return True
|
||||
else:
|
||||
raise error
|
||||
|
@ -342,7 +362,6 @@ class MtProtoSender:
|
|||
|
||||
# TODO For now, simply ack msg_new.answer_msg_id
|
||||
# Relevant tdesktop source code: https://goo.gl/VvpCC6
|
||||
await self._send_acknowledge(msg_new.answer_msg_id)
|
||||
return True
|
||||
|
||||
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
|
||||
|
@ -351,7 +370,6 @@ class MtProtoSender:
|
|||
|
||||
# TODO For now, simply ack msg_new.answer_msg_id
|
||||
# Relevant tdesktop source code: https://goo.gl/G7DPsR
|
||||
await self._send_acknowledge(msg_new.answer_msg_id)
|
||||
return True
|
||||
|
||||
async def _handle_new_session_created(self, msg_id, sequence, reader):
|
||||
|
@ -379,9 +397,6 @@ class MtProtoSender:
|
|||
reader.read_int(), reader.tgread_string()
|
||||
)
|
||||
|
||||
# Acknowledge that we received the error
|
||||
await self._send_acknowledge(request_id)
|
||||
|
||||
if request:
|
||||
request.rpc_error = error
|
||||
request.confirm_received.set()
|
||||
|
@ -412,11 +427,6 @@ class MtProtoSender:
|
|||
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
|
||||
self._logger.debug('Handling gzip packed data')
|
||||
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
|
||||
# We are reentering process_msg, which seemingly the same msg_id
|
||||
# to the self._need_confirmation set. Remove it from there first
|
||||
# to avoid any future conflicts (i.e. if we "ignore" messages
|
||||
# that we are already aware of, see 1a91c02 and old 63dfb1e)
|
||||
self._need_confirmation -= {msg_id}
|
||||
return await self._process_msg(msg_id, sequence, compressed_reader, state)
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import logging
|
||||
import os
|
||||
import warnings
|
||||
import asyncio
|
||||
from datetime import timedelta, datetime
|
||||
from hashlib import md5
|
||||
from io import BytesIO
|
||||
from time import sleep
|
||||
from asyncio import Lock
|
||||
|
||||
from . import helpers as utils
|
||||
from .crypto import rsa, CdnDecrypter
|
||||
|
@ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode
|
|||
from .tl import TLObject, Session
|
||||
from .tl.all_tlobjects import LAYER
|
||||
from .tl.functions import (
|
||||
InitConnectionRequest, InvokeWithLayerRequest
|
||||
InitConnectionRequest, InvokeWithLayerRequest, PingRequest
|
||||
)
|
||||
from .tl.functions.auth import (
|
||||
ImportAuthorizationRequest, ExportAuthorizationRequest
|
||||
|
@ -67,6 +67,7 @@ class TelegramBareClient:
|
|||
connection_mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None,
|
||||
timeout=timedelta(seconds=5),
|
||||
loop=None,
|
||||
**kwargs):
|
||||
"""Refer to TelegramClient.__init__ for docs on this method"""
|
||||
if not api_id or not api_hash:
|
||||
|
@ -82,6 +83,8 @@ class TelegramBareClient:
|
|||
'The given session must be a str or a Session instance.'
|
||||
)
|
||||
|
||||
self._loop = loop if loop else asyncio.get_event_loop()
|
||||
|
||||
self.session = session
|
||||
self.api_id = int(api_id)
|
||||
self.api_hash = api_hash
|
||||
|
@ -92,12 +95,18 @@ class TelegramBareClient:
|
|||
# that calls .connect(). Every other thread will spawn a new
|
||||
# temporary connection. The connection on this one is always
|
||||
# kept open so Telegram can send us updates.
|
||||
self._sender = MtProtoSender(self.session, Connection(
|
||||
mode=connection_mode, proxy=proxy, timeout=timeout
|
||||
))
|
||||
self._sender = MtProtoSender(
|
||||
self.session,
|
||||
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
|
||||
self._loop
|
||||
)
|
||||
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
# Two coroutines may be calling reconnect() when the connection is lost,
|
||||
# we only want one to actually perform the reconnection.
|
||||
self._reconnect_lock = Lock(loop=self._loop)
|
||||
|
||||
# Cache "exported" sessions as 'dc_id: Session' not to recreate
|
||||
# them all the time since generating a new key is a relatively
|
||||
# expensive operation.
|
||||
|
@ -105,7 +114,7 @@ class TelegramBareClient:
|
|||
|
||||
# This member will process updates if enabled.
|
||||
# One may change self.updates.enabled at any later point.
|
||||
self.updates = UpdateState(workers=None)
|
||||
self.updates = UpdateState(self._loop)
|
||||
|
||||
# Used on connection - the user may modify these and reconnect
|
||||
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
|
||||
|
@ -129,10 +138,11 @@ class TelegramBareClient:
|
|||
# Uploaded files cache so subsequent calls are instant
|
||||
self._upload_cache = {}
|
||||
|
||||
# Default PingRequest delay
|
||||
self._last_ping = datetime.now()
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
self._recv_loop = None
|
||||
self._ping_loop = None
|
||||
|
||||
# Default PingRequest delay
|
||||
self._ping_delay = timedelta(minutes=1)
|
||||
|
||||
# endregion
|
||||
|
||||
|
@ -167,6 +177,7 @@ class TelegramBareClient:
|
|||
self.session.auth_key, self.session.time_offset = \
|
||||
await authenticator.do_authentication(self._sender.connection)
|
||||
except BrokenAuthKeyError:
|
||||
self._user_connected = False
|
||||
return False
|
||||
|
||||
self.session.layer = LAYER
|
||||
|
@ -213,7 +224,7 @@ class TelegramBareClient:
|
|||
# This is fine, probably layer migration
|
||||
self._logger.debug('Found invalid item, probably migrating', e)
|
||||
self.disconnect()
|
||||
return self.connect(
|
||||
return await self.connect(
|
||||
_exported_auth=_exported_auth,
|
||||
_sync_updates=_sync_updates,
|
||||
_cdn=_cdn
|
||||
|
@ -263,7 +274,17 @@ class TelegramBareClient:
|
|||
"""
|
||||
if new_dc is None:
|
||||
# Assume we are disconnected due to some error, so connect again
|
||||
return await self.connect()
|
||||
try:
|
||||
await self._reconnect_lock.acquire()
|
||||
# Another thread may have connected again, so check that first
|
||||
if self.is_connected():
|
||||
return True
|
||||
|
||||
return await self.connect()
|
||||
except ConnectionResetError:
|
||||
return False
|
||||
finally:
|
||||
self._reconnect_lock.release()
|
||||
else:
|
||||
self.disconnect()
|
||||
self.session.auth_key = None # Force creating new auth_key
|
||||
|
@ -339,7 +360,8 @@ class TelegramBareClient:
|
|||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
timeout=self._sender.connection.get_timeout(),
|
||||
loop=self._loop
|
||||
)
|
||||
await client.connect(_exported_auth=export_auth, _sync_updates=False)
|
||||
client._authorized = True # We exported the auth, so we got auth
|
||||
|
@ -358,7 +380,8 @@ class TelegramBareClient:
|
|||
client = TelegramBareClient(
|
||||
session, self.api_id, self.api_hash,
|
||||
proxy=self._sender.connection.conn.proxy,
|
||||
timeout=self._sender.connection.get_timeout()
|
||||
timeout=self._sender.connection.get_timeout(),
|
||||
loop=self._loop
|
||||
)
|
||||
|
||||
# This will make use of the new RSA keys for this specific CDN.
|
||||
|
@ -383,53 +406,51 @@ class TelegramBareClient:
|
|||
x.content_related for x in requests):
|
||||
raise ValueError('You can only invoke requests, not types!')
|
||||
|
||||
# TODO Determine the sender to be used (main or a new connection)
|
||||
sender = self._sender # .clone(), .connect()
|
||||
# We're on the same connection so no need to pass update_state=None
|
||||
# to avoid getting messages that we haven't acknowledged yet.
|
||||
# We should call receive from this thread if there's no background
|
||||
# thread reading or if the server disconnected us and we're trying
|
||||
# to reconnect. This is because the read thread may either be
|
||||
# locked also trying to reconnect or we may be said thread already.
|
||||
call_receive = self._recv_loop is None
|
||||
|
||||
try:
|
||||
for _ in range(retries):
|
||||
result = await self._invoke(sender, *requests)
|
||||
if result is not None:
|
||||
return result
|
||||
for retry in range(retries):
|
||||
result = await self._invoke(call_receive, retry, *requests)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
raise ValueError('Number of retries reached 0.')
|
||||
finally:
|
||||
if sender != self._sender:
|
||||
sender.disconnect() # Close temporary connections
|
||||
raise ValueError('Number of retries reached 0.')
|
||||
|
||||
# Let people use client.invoke(SomeRequest()) instead client(...)
|
||||
invoke = __call__
|
||||
|
||||
async def _invoke(self, sender, *requests):
|
||||
async def _invoke(self, call_receive, retry, *requests):
|
||||
try:
|
||||
# Ensure that we start with no previous errors (i.e. resending)
|
||||
for x in requests:
|
||||
x.confirm_received.clear()
|
||||
x.rpc_error = None
|
||||
|
||||
await sender.send(*requests)
|
||||
while not all(x.confirm_received.is_set() for x in requests):
|
||||
await sender.receive(update_state=self.updates)
|
||||
await self._sender.send(*requests)
|
||||
|
||||
except TimeoutError:
|
||||
pass # We will just retry
|
||||
if not call_receive:
|
||||
await asyncio.wait(
|
||||
list(map(lambda x: x.confirm_received.wait(), requests)),
|
||||
timeout=self._sender.connection.get_timeout(),
|
||||
loop=self._loop
|
||||
)
|
||||
else:
|
||||
while not all(x.confirm_received.is_set() for x in requests):
|
||||
await self._sender.receive(update_state=self.updates)
|
||||
|
||||
except ConnectionResetError:
|
||||
if not self._user_connected:
|
||||
# Only attempt reconnecting if we're authorized
|
||||
if not self._user_connected or self._reconnect_lock.locked():
|
||||
# Only attempt reconnecting if the user called connect and not
|
||||
# reconnecting already.
|
||||
raise
|
||||
|
||||
self._logger.debug('Server disconnected us. Reconnecting and '
|
||||
'resending request...')
|
||||
|
||||
if sender != self._sender:
|
||||
# TODO Try reconnecting forever too?
|
||||
await sender.connect()
|
||||
else:
|
||||
while self._user_connected and not await self._reconnect():
|
||||
sleep(0.1) # Retry forever until we can send the request
|
||||
'resending request... (%d)' % retry)
|
||||
await self._reconnect()
|
||||
if not self._sender.is_connected():
|
||||
await asyncio.sleep(retry + 1, loop=self._loop)
|
||||
return None
|
||||
|
||||
try:
|
||||
|
@ -453,7 +474,7 @@ class TelegramBareClient:
|
|||
)
|
||||
|
||||
await self._reconnect(new_dc=e.new_dc)
|
||||
return await self._invoke(sender, *requests)
|
||||
return None
|
||||
|
||||
except ServerError as e:
|
||||
# Telegram is having some issues, just retry
|
||||
|
@ -468,7 +489,8 @@ class TelegramBareClient:
|
|||
self._logger.debug(
|
||||
'Sleep of %d seconds below threshold, sleeping' % e.seconds
|
||||
)
|
||||
sleep(e.seconds)
|
||||
await asyncio.sleep(e.seconds, loop=self._loop)
|
||||
return None
|
||||
|
||||
# Some really basic functionality
|
||||
|
||||
|
@ -671,16 +693,9 @@ class TelegramBareClient:
|
|||
"""
|
||||
self.updates.process(await self(GetStateRequest()))
|
||||
|
||||
def add_update_handler(self, handler):
|
||||
async def add_update_handler(self, handler):
|
||||
"""Adds an update handler (a function which takes a TLObject,
|
||||
an update, as its parameter) and listens for updates"""
|
||||
if self.updates.workers is None:
|
||||
warnings.warn(
|
||||
"You have not setup any workers, so you won't receive updates."
|
||||
" Pass update_workers=4 when creating the TelegramClient,"
|
||||
" or set client.self.updates.workers = 4"
|
||||
)
|
||||
|
||||
self.updates.handlers.append(handler)
|
||||
|
||||
def remove_update_handler(self, handler):
|
||||
|
@ -695,6 +710,60 @@ class TelegramBareClient:
|
|||
|
||||
def _set_connected_and_authorized(self):
|
||||
self._authorized = True
|
||||
# TODO self.updates.setup_workers()
|
||||
if self._recv_loop is None:
|
||||
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
|
||||
if self._ping_loop is None:
|
||||
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
|
||||
|
||||
async def _ping_loop_impl(self):
|
||||
while self._user_connected:
|
||||
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
|
||||
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
|
||||
self._ping_loop = None
|
||||
|
||||
async def _recv_loop_impl(self):
|
||||
need_reconnect = False
|
||||
while self._user_connected:
|
||||
try:
|
||||
if need_reconnect:
|
||||
need_reconnect = False
|
||||
while self._user_connected and not await self._reconnect():
|
||||
await asyncio.sleep(0.1, loop=self._loop) # Retry forever, this is instant messaging
|
||||
|
||||
await self._sender.receive(update_state=self.updates)
|
||||
except TimeoutError:
|
||||
# No problem.
|
||||
pass
|
||||
except ConnectionError as error:
|
||||
self._logger.debug(error)
|
||||
need_reconnect = True
|
||||
await asyncio.sleep(1, loop=self._loop)
|
||||
except Exception as error:
|
||||
# Unknown exception, pass it to the main thread
|
||||
self._logger.debug(
|
||||
'[ERROR] Unknown error on the read loop, please report',
|
||||
error
|
||||
)
|
||||
|
||||
try:
|
||||
import socks
|
||||
if isinstance(error, (
|
||||
socks.GeneralProxyError, socks.ProxyConnectionError
|
||||
)):
|
||||
# This is a known error, and it's not related to
|
||||
# Telegram but rather to the proxy. Disconnect and
|
||||
# hand it over to the main thread.
|
||||
self._background_error = error
|
||||
self.disconnect()
|
||||
break
|
||||
except ImportError:
|
||||
"Not using PySocks, so it can't be a socket error"
|
||||
|
||||
# If something strange happens we don't want to enter an
|
||||
# infinite loop where all we do is raise an exception, so
|
||||
# add a little sleep to avoid the CPU usage going mad.
|
||||
await asyncio.sleep(0.1, loop=self._loop)
|
||||
break
|
||||
self._recv_loop = None
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient):
|
|||
connection_mode=ConnectionMode.TCP_FULL,
|
||||
proxy=None,
|
||||
timeout=timedelta(seconds=5),
|
||||
loop=None,
|
||||
**kwargs):
|
||||
"""Initializes the Telegram client with the specified API ID and Hash.
|
||||
|
||||
|
@ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient):
|
|||
connection_mode=connection_mode,
|
||||
proxy=proxy,
|
||||
timeout=timeout,
|
||||
loop=loop,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
@ -202,7 +204,7 @@ class TelegramClient(TelegramBareClient):
|
|||
"""Gets "me" (the self user) which is currently authenticated,
|
||||
or None if the request fails (hence, not authenticated)."""
|
||||
try:
|
||||
return await self(GetUsersRequest([InputUserSelf()]))[0]
|
||||
return (await self(GetUsersRequest([InputUserSelf()])))[0]
|
||||
except UnauthorizedError:
|
||||
return None
|
||||
|
||||
|
@ -313,6 +315,7 @@ class TelegramClient(TelegramBareClient):
|
|||
reply_to_msg_id=self._get_reply_to(reply_to)
|
||||
)
|
||||
result = await self(request)
|
||||
|
||||
if isinstance(result, UpdateShortSentMessage):
|
||||
return Message(
|
||||
id=result.id,
|
||||
|
|
|
@ -11,6 +11,15 @@ class MessageContainer(TLObject):
|
|||
self.content_related = False
|
||||
self.messages = messages
|
||||
|
||||
def to_dict(self, recursive=True):
|
||||
return {
|
||||
'content_related': self.content_related,
|
||||
'messages':
|
||||
([] if self.messages is None else [
|
||||
None if x is None else x.to_dict() for x in self.messages
|
||||
]) if recursive else self.messages,
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
return struct.pack(
|
||||
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
|
||||
|
@ -25,3 +34,9 @@ class MessageContainer(TLObject):
|
|||
inner_sequence = reader.read_int()
|
||||
inner_length = reader.read_int()
|
||||
yield inner_msg_id, inner_sequence, inner_length
|
||||
|
||||
def __str__(self):
|
||||
return TLObject.pretty_format(self)
|
||||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import struct
|
||||
import logging
|
||||
|
||||
from . import TLObject, GzipPacked
|
||||
|
||||
|
@ -11,7 +12,23 @@ class TLMessage(TLObject):
|
|||
self.msg_id = session.get_new_msg_id()
|
||||
self.seq_no = session.generate_sequence(request.content_related)
|
||||
self.request = request
|
||||
self.container_msg_id = None
|
||||
logging.getLogger(__name__).debug(self)
|
||||
|
||||
def to_dict(self, recursive=True):
|
||||
return {
|
||||
'msg_id': self.msg_id,
|
||||
'seq_no': self.seq_no,
|
||||
'request': self.request,
|
||||
'container_msg_id': self.container_msg_id,
|
||||
}
|
||||
|
||||
def __bytes__(self):
|
||||
body = GzipPacked.gzip_if_smaller(self.request)
|
||||
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
|
||||
|
||||
def __str__(self):
|
||||
return TLObject.pretty_format(self)
|
||||
|
||||
def stringify(self):
|
||||
return TLObject.pretty_format(self, indent=0)
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
from datetime import datetime
|
||||
from threading import Event
|
||||
|
||||
|
||||
class TLObject:
|
||||
def __init__(self):
|
||||
self.request_msg_id = 0 # Long
|
||||
|
||||
self.confirm_received = Event()
|
||||
self.confirm_received = None
|
||||
self.rpc_error = None
|
||||
|
||||
# These should be overrode
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
import pickle
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from threading import RLock, Event, Thread
|
||||
|
||||
from .tl import types as tl
|
||||
|
||||
|
@ -13,177 +13,72 @@ class UpdateState:
|
|||
"""
|
||||
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
|
||||
|
||||
def __init__(self, workers=None):
|
||||
"""
|
||||
:param workers: This integer parameter has three possible cases:
|
||||
workers is None: Updates will *not* be stored on self.
|
||||
workers = 0: Another thread is responsible for calling self.poll()
|
||||
workers > 0: 'workers' background threads will be spawned, any
|
||||
any of them will invoke all the self.handlers.
|
||||
"""
|
||||
self._workers = workers
|
||||
self._worker_threads = []
|
||||
|
||||
def __init__(self, loop=None):
|
||||
self.handlers = []
|
||||
self._updates_lock = RLock()
|
||||
self._updates_available = Event()
|
||||
self._updates = deque()
|
||||
self._latest_updates = deque(maxlen=10)
|
||||
self._loop = loop if loop else asyncio.get_event_loop()
|
||||
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
# https://core.telegram.org/api/updates
|
||||
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
|
||||
|
||||
def can_poll(self):
|
||||
"""Returns True if a call to .poll() won't lock"""
|
||||
return self._updates_available.is_set()
|
||||
|
||||
def poll(self, timeout=None):
|
||||
"""Polls an update or blocks until an update object is available.
|
||||
If 'timeout is not None', it should be a floating point value,
|
||||
and the method will 'return None' if waiting times out.
|
||||
"""
|
||||
if not self._updates_available.wait(timeout=timeout):
|
||||
return
|
||||
|
||||
with self._updates_lock:
|
||||
if not self._updates_available.is_set():
|
||||
return
|
||||
|
||||
update = self._updates.popleft()
|
||||
if not self._updates:
|
||||
self._updates_available.clear()
|
||||
|
||||
if isinstance(update, Exception):
|
||||
raise update # Some error was set through (surely StopIteration)
|
||||
|
||||
return update
|
||||
|
||||
def get_workers(self):
|
||||
return self._workers
|
||||
|
||||
def set_workers(self, n):
|
||||
"""Changes the number of workers running.
|
||||
If 'n is None', clears all pending updates from memory.
|
||||
"""
|
||||
self.stop_workers()
|
||||
self._workers = n
|
||||
if n is None:
|
||||
self._updates.clear()
|
||||
else:
|
||||
self.setup_workers()
|
||||
|
||||
workers = property(fget=get_workers, fset=set_workers)
|
||||
|
||||
def stop_workers(self):
|
||||
"""Raises "StopIterationException" on the worker threads to stop them,
|
||||
and also clears all of them off the list
|
||||
"""
|
||||
if self._workers:
|
||||
with self._updates_lock:
|
||||
# Insert at the beginning so the very next poll causes an error
|
||||
# on all the worker threads
|
||||
# TODO Should this reset the pts and such?
|
||||
for _ in range(self._workers):
|
||||
self._updates.appendleft(StopIteration())
|
||||
self._updates_available.set()
|
||||
|
||||
for t in self._worker_threads:
|
||||
t.join()
|
||||
|
||||
self._worker_threads.clear()
|
||||
|
||||
def setup_workers(self):
|
||||
if self._worker_threads or not self._workers:
|
||||
# There already are workers, or workers is None or 0. Do nothing.
|
||||
return
|
||||
|
||||
for i in range(self._workers):
|
||||
thread = Thread(
|
||||
target=UpdateState._worker_loop,
|
||||
name='UpdateWorker{}'.format(i),
|
||||
daemon=True,
|
||||
args=(self, i)
|
||||
)
|
||||
self._worker_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
def _worker_loop(self, wid):
|
||||
while True:
|
||||
try:
|
||||
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
|
||||
# TODO Maybe people can add different handlers per update type
|
||||
if update:
|
||||
for handler in self.handlers:
|
||||
handler(update)
|
||||
except StopIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
# We don't want to crash a worker thread due to any reason
|
||||
self._logger.debug(
|
||||
'[ERROR] Unhandled exception on worker {}'.format(wid), e
|
||||
)
|
||||
def handle_update(self, update):
|
||||
for handler in self.handlers:
|
||||
asyncio.ensure_future(handler(update), loop=self._loop)
|
||||
|
||||
def process(self, update):
|
||||
"""Processes an update object. This method is normally called by
|
||||
the library itself.
|
||||
"""
|
||||
if self._workers is None:
|
||||
return # No processing needs to be done if nobody's working
|
||||
if isinstance(update, tl.updates.State):
|
||||
self._state = update
|
||||
return # Nothing else to be done
|
||||
|
||||
with self._updates_lock:
|
||||
if isinstance(update, tl.updates.State):
|
||||
self._state = update
|
||||
return # Nothing else to be done
|
||||
pts = getattr(update, 'pts', self._state.pts)
|
||||
if hasattr(update, 'pts') and pts <= self._state.pts:
|
||||
return # We already handled this update
|
||||
|
||||
pts = getattr(update, 'pts', self._state.pts)
|
||||
if hasattr(update, 'pts') and pts <= self._state.pts:
|
||||
return # We already handled this update
|
||||
self._state.pts = pts
|
||||
|
||||
self._state.pts = pts
|
||||
# TODO There must be a better way to handle updates rather than
|
||||
# keeping a queue with the latest updates only, and handling
|
||||
# the 'pts' correctly should be enough. However some updates
|
||||
# like UpdateUserStatus (even inside UpdateShort) will be called
|
||||
# repeatedly very often if invoking anything inside an update
|
||||
# handler. TODO Figure out why.
|
||||
"""
|
||||
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
|
||||
client.connect()
|
||||
def handle(u):
|
||||
client.get_me()
|
||||
client.add_update_handler(handle)
|
||||
input('Enter to exit.')
|
||||
"""
|
||||
data = pickle.dumps(update.to_dict())
|
||||
if data in self._latest_updates:
|
||||
return # Duplicated too
|
||||
|
||||
# TODO There must be a better way to handle updates rather than
|
||||
# keeping a queue with the latest updates only, and handling
|
||||
# the 'pts' correctly should be enough. However some updates
|
||||
# like UpdateUserStatus (even inside UpdateShort) will be called
|
||||
# repeatedly very often if invoking anything inside an update
|
||||
# handler. TODO Figure out why.
|
||||
"""
|
||||
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
|
||||
client.connect()
|
||||
def handle(u):
|
||||
client.get_me()
|
||||
client.add_update_handler(handle)
|
||||
input('Enter to exit.')
|
||||
"""
|
||||
data = pickle.dumps(update.to_dict())
|
||||
if data in self._latest_updates:
|
||||
return # Duplicated too
|
||||
self._latest_updates.append(data)
|
||||
|
||||
self._latest_updates.append(data)
|
||||
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
|
||||
# Expand "Updates" into "Update", and pass these to callbacks.
|
||||
# Since .users and .chats have already been processed, we
|
||||
# don't need to care about those either.
|
||||
if isinstance(update, tl.UpdateShort):
|
||||
self.handle_update(update.update)
|
||||
|
||||
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
|
||||
# Expand "Updates" into "Update", and pass these to callbacks.
|
||||
# Since .users and .chats have already been processed, we
|
||||
# don't need to care about those either.
|
||||
if isinstance(update, tl.UpdateShort):
|
||||
self._updates.append(update.update)
|
||||
self._updates_available.set()
|
||||
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
|
||||
for upd in update.updates:
|
||||
self.handle_update(upd)
|
||||
|
||||
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
|
||||
self._updates.extend(update.updates)
|
||||
self._updates_available.set()
|
||||
elif not isinstance(update, tl.UpdatesTooLong):
|
||||
# TODO Handle "Updates too long"
|
||||
self.handle_update(update)
|
||||
|
||||
elif not isinstance(update, tl.UpdatesTooLong):
|
||||
# TODO Handle "Updates too long"
|
||||
self._updates.append(update)
|
||||
self._updates_available.set()
|
||||
|
||||
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
|
||||
self._updates.append(update)
|
||||
self._updates_available.set()
|
||||
else:
|
||||
self._logger.debug('Ignoring "update" of type {}'.format(
|
||||
type(update).__name__)
|
||||
)
|
||||
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
|
||||
self.handle_update(update)
|
||||
else:
|
||||
self._logger.debug('Ignoring "update" of type {}'.format(
|
||||
type(update).__name__)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user