Merge pull request #370 from andr-04/asyncio

Made update system for asyncio functional
This commit is contained in:
Lonami 2017-10-28 11:07:41 +02:00 committed by GitHub
commit 6dc0ee9d6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 277 additions and 256 deletions

View File

@ -5,13 +5,18 @@ import socket
from datetime import timedelta
from io import BytesIO, BufferedWriter
loop = asyncio.get_event_loop()
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
}
class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
self.proxy = proxy
self._socket = None
self._loop = loop if loop else asyncio.get_event_loop()
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
@ -31,7 +36,7 @@ class TcpClient:
else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout)
self._socket.setblocking(False)
async def connect(self, ip, port):
"""Connects to the specified IP and port number.
@ -42,20 +47,27 @@ class TcpClient:
else:
mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True:
try:
while not self._socket:
if not self._socket:
self._recreate_socket(mode)
await loop.sock_connect(self._socket, address)
await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
except OSError as e:
# There are some errors that we know how to handle, and
# the loop will allow us to retry
if e.errno == errno.EBADF:
if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL):
# Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
else:
raise
@ -81,13 +93,14 @@ class TcpClient:
raise ConnectionResetError()
try:
await loop.sock_sendall(self._socket, data)
except socket.timeout as e:
await asyncio.wait_for(self._loop.sock_sendall(self._socket, data),
timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except BrokenPipeError:
self._raise_connection_reset()
except OSError as e:
if e.errno == errno.EBADF:
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset()
else:
raise
@ -104,11 +117,12 @@ class TcpClient:
bytes_left = size
while bytes_left != 0:
try:
partial = await loop.sock_recv(self._socket, bytes_left)
except socket.timeout as e:
partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left),
timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e:
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK:
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset()
else:
raise

View File

@ -43,13 +43,13 @@ class Connection:
"""
def __init__(self, mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)):
proxy=None, timeout=timedelta(seconds=5), loop=None):
self._mode = mode
self._send_counter = 0
self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy, timeout=timeout)
self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
# Sending messages
if mode == ConnectionMode.TCP_FULL:
@ -206,7 +206,7 @@ class Connection:
return await self.conn.read(length)
async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length))
return self._aes_decrypt.encrypt(await self.conn.read(length))
# endregion

View File

@ -1,6 +1,8 @@
import gzip
import logging
import struct
import asyncio
from asyncio import Event
from .. import helpers as utils
from ..crypto import AES
@ -30,17 +32,15 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed.
"""
def __init__(self, session, connection):
def __init__(self, session, connection, loop=None):
"""Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'.
"""
self.session = session
self.connection = connection
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__)
# Message IDs that need confirmation
self._need_confirmation = set()
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
@ -54,12 +54,11 @@ class MtProtoSender:
def disconnect(self):
"""Disconnects from the server"""
self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending()
def clone(self):
"""Creates a copy of this MtProtoSender as a new connection"""
return MtProtoSender(self.session, self.connection.clone())
return MtProtoSender(self.session, self.connection.clone(), self._loop)
# region Send and receive
@ -67,21 +66,23 @@ class MtProtoSender:
"""Sends the specified MTProtoRequest, previously sending any message
which needed confirmation."""
# Prepare the event of every request
for r in requests:
if r.confirm_received is None:
r.confirm_received = Event(loop=self._loop)
else:
r.confirm_received.clear()
# Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages})
# Pack everything in the same container if we need to send AckRequests
if self._need_confirmation:
messages.append(
TLMessage(self.session, MsgsAck(list(self._need_confirmation)))
)
self._need_confirmation.clear()
if len(messages) == 1:
message = messages[0]
else:
message = TLMessage(self.session, MessageContainer(messages))
for m in messages:
m.container_msg_id = message.msg_id
await self._send_message(message)
@ -115,6 +116,7 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
# endregion
@ -174,7 +176,6 @@ class MtProtoSender:
"""
# TODO Check salt, session_id and sequence_number
self._need_confirmation.add(msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
@ -210,7 +211,7 @@ class MtProtoSender:
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object()
assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's
# Ignore every ack request *unless* when logging out,
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in ack.msg_ids:
@ -259,11 +260,29 @@ class MtProtoSender:
if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request
def _pop_requests_of_container(self, container_msg_id):
msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id]
requests = [msg.request for msg in msgs]
for msg in msgs:
self._pending_receive.pop(msg.msg_id, None)
return requests
def _clear_all_pending(self):
for r in self._pending_receive.values():
r.request.confirm_received.set()
self._pending_receive.clear()
async def _resend_request(self, msg_id):
request = self._pop_request(msg_id)
if request:
self._logger.debug('requests is about to resend')
await self.send(request)
return
requests = self._pop_requests_of_container(msg_id)
if requests:
self._logger.debug('container of requests is about to resend')
await self.send(*requests)
async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong')
pong = reader.tgread_object()
@ -305,9 +324,7 @@ class MtProtoSender:
)[0]
self.session.save()
request = self._pop_request(bad_salt.bad_msg_id)
if request:
await self.send(request)
await self._resend_request(bad_salt.bad_msg_id)
return True
@ -323,15 +340,18 @@ class MtProtoSender:
self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.')
await self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64
await self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16
await self._resend_request(bad_msg.bad_msg_id)
return True
else:
raise error
@ -342,7 +362,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
@ -351,7 +370,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_new_session_created(self, msg_id, sequence, reader):
@ -379,9 +397,6 @@ class MtProtoSender:
reader.read_int(), reader.tgread_string()
)
# Acknowledge that we received the error
await self._send_acknowledge(request_id)
if request:
request.rpc_error = error
request.confirm_received.set()
@ -412,11 +427,6 @@ class MtProtoSender:
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
self._logger.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
# We are reentering process_msg, which seemingly the same msg_id
# to the self._need_confirmation set. Remove it from there first
# to avoid any future conflicts (i.e. if we "ignore" messages
# that we are already aware of, see 1a91c02 and old 63dfb1e)
self._need_confirmation -= {msg_id}
return await self._process_msg(msg_id, sequence, compressed_reader, state)
# endregion

View File

@ -1,10 +1,10 @@
import logging
import os
import warnings
import asyncio
from datetime import timedelta, datetime
from hashlib import md5
from io import BytesIO
from time import sleep
from asyncio import Lock
from . import helpers as utils
from .crypto import rsa, CdnDecrypter
@ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .tl import TLObject, Session
from .tl.all_tlobjects import LAYER
from .tl.functions import (
InitConnectionRequest, InvokeWithLayerRequest
InitConnectionRequest, InvokeWithLayerRequest, PingRequest
)
from .tl.functions.auth import (
ImportAuthorizationRequest, ExportAuthorizationRequest
@ -67,6 +67,7 @@ class TelegramBareClient:
connection_mode=ConnectionMode.TCP_FULL,
proxy=None,
timeout=timedelta(seconds=5),
loop=None,
**kwargs):
"""Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash:
@ -82,6 +83,8 @@ class TelegramBareClient:
'The given session must be a str or a Session instance.'
)
self._loop = loop if loop else asyncio.get_event_loop()
self.session = session
self.api_id = int(api_id)
self.api_hash = api_hash
@ -92,12 +95,18 @@ class TelegramBareClient:
# that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always
# kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection(
mode=connection_mode, proxy=proxy, timeout=timeout
))
self._sender = MtProtoSender(
self.session,
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._loop
)
self._logger = logging.getLogger(__name__)
# Two coroutines may be calling reconnect() when the connection is lost,
# we only want one to actually perform the reconnection.
self._reconnect_lock = Lock(loop=self._loop)
# Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively
# expensive operation.
@ -105,7 +114,7 @@ class TelegramBareClient:
# This member will process updates if enabled.
# One may change self.updates.enabled at any later point.
self.updates = UpdateState(workers=None)
self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
@ -129,10 +138,11 @@ class TelegramBareClient:
# Uploaded files cache so subsequent calls are instant
self._upload_cache = {}
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
self._recv_loop = None
self._ping_loop = None
# Default PingRequest delay
self._ping_delay = timedelta(minutes=1)
# endregion
@ -167,6 +177,7 @@ class TelegramBareClient:
self.session.auth_key, self.session.time_offset = \
await authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError:
self._user_connected = False
return False
self.session.layer = LAYER
@ -213,7 +224,7 @@ class TelegramBareClient:
# This is fine, probably layer migration
self._logger.debug('Found invalid item, probably migrating', e)
self.disconnect()
return self.connect(
return await self.connect(
_exported_auth=_exported_auth,
_sync_updates=_sync_updates,
_cdn=_cdn
@ -263,7 +274,17 @@ class TelegramBareClient:
"""
if new_dc is None:
# Assume we are disconnected due to some error, so connect again
return await self.connect()
try:
await self._reconnect_lock.acquire()
# Another thread may have connected again, so check that first
if self.is_connected():
return True
return await self.connect()
except ConnectionResetError:
return False
finally:
self._reconnect_lock.release()
else:
self.disconnect()
self.session.auth_key = None # Force creating new auth_key
@ -339,7 +360,8 @@ class TelegramBareClient:
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
await client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth
@ -358,7 +380,8 @@ class TelegramBareClient:
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
# This will make use of the new RSA keys for this specific CDN.
@ -383,53 +406,51 @@ class TelegramBareClient:
x.content_related for x in requests):
raise ValueError('You can only invoke requests, not types!')
# TODO Determine the sender to be used (main or a new connection)
sender = self._sender # .clone(), .connect()
# We're on the same connection so no need to pass update_state=None
# to avoid getting messages that we haven't acknowledged yet.
# We should call receive from this thread if there's no background
# thread reading or if the server disconnected us and we're trying
# to reconnect. This is because the read thread may either be
# locked also trying to reconnect or we may be said thread already.
call_receive = self._recv_loop is None
try:
for _ in range(retries):
result = await self._invoke(sender, *requests)
if result is not None:
return result
for retry in range(retries):
result = await self._invoke(call_receive, retry, *requests)
if result is not None:
return result
raise ValueError('Number of retries reached 0.')
finally:
if sender != self._sender:
sender.disconnect() # Close temporary connections
raise ValueError('Number of retries reached 0.')
# Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__
async def _invoke(self, sender, *requests):
async def _invoke(self, call_receive, retry, *requests):
try:
# Ensure that we start with no previous errors (i.e. resending)
for x in requests:
x.confirm_received.clear()
x.rpc_error = None
await sender.send(*requests)
while not all(x.confirm_received.is_set() for x in requests):
await sender.receive(update_state=self.updates)
await self._sender.send(*requests)
except TimeoutError:
pass # We will just retry
if not call_receive:
await asyncio.wait(
list(map(lambda x: x.confirm_received.wait(), requests)),
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
else:
while not all(x.confirm_received.is_set() for x in requests):
await self._sender.receive(update_state=self.updates)
except ConnectionResetError:
if not self._user_connected:
# Only attempt reconnecting if we're authorized
if not self._user_connected or self._reconnect_lock.locked():
# Only attempt reconnecting if the user called connect and not
# reconnecting already.
raise
self._logger.debug('Server disconnected us. Reconnecting and '
'resending request...')
if sender != self._sender:
# TODO Try reconnecting forever too?
await sender.connect()
else:
while self._user_connected and not await self._reconnect():
sleep(0.1) # Retry forever until we can send the request
'resending request... (%d)' % retry)
await self._reconnect()
if not self._sender.is_connected():
await asyncio.sleep(retry + 1, loop=self._loop)
return None
try:
@ -453,7 +474,7 @@ class TelegramBareClient:
)
await self._reconnect(new_dc=e.new_dc)
return await self._invoke(sender, *requests)
return None
except ServerError as e:
# Telegram is having some issues, just retry
@ -468,7 +489,8 @@ class TelegramBareClient:
self._logger.debug(
'Sleep of %d seconds below threshold, sleeping' % e.seconds
)
sleep(e.seconds)
await asyncio.sleep(e.seconds, loop=self._loop)
return None
# Some really basic functionality
@ -671,16 +693,9 @@ class TelegramBareClient:
"""
self.updates.process(await self(GetStateRequest()))
def add_update_handler(self, handler):
async def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates"""
if self.updates.workers is None:
warnings.warn(
"You have not setup any workers, so you won't receive updates."
" Pass update_workers=4 when creating the TelegramClient,"
" or set client.self.updates.workers = 4"
)
self.updates.handlers.append(handler)
def remove_update_handler(self, handler):
@ -695,6 +710,60 @@ class TelegramBareClient:
def _set_connected_and_authorized(self):
self._authorized = True
# TODO self.updates.setup_workers()
if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
self._ping_loop = None
async def _recv_loop_impl(self):
need_reconnect = False
while self._user_connected:
try:
if need_reconnect:
need_reconnect = False
while self._user_connected and not await self._reconnect():
await asyncio.sleep(0.1, loop=self._loop) # Retry forever, this is instant messaging
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
pass
except ConnectionError as error:
self._logger.debug(error)
need_reconnect = True
await asyncio.sleep(1, loop=self._loop)
except Exception as error:
# Unknown exception, pass it to the main thread
self._logger.debug(
'[ERROR] Unknown error on the read loop, please report',
error
)
try:
import socks
if isinstance(error, (
socks.GeneralProxyError, socks.ProxyConnectionError
)):
# This is a known error, and it's not related to
# Telegram but rather to the proxy. Disconnect and
# hand it over to the main thread.
self._background_error = error
self.disconnect()
break
except ImportError:
"Not using PySocks, so it can't be a socket error"
# If something strange happens we don't want to enter an
# infinite loop where all we do is raise an exception, so
# add a little sleep to avoid the CPU usage going mad.
await asyncio.sleep(0.1, loop=self._loop)
break
self._recv_loop = None
# endregion

View File

@ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=ConnectionMode.TCP_FULL,
proxy=None,
timeout=timedelta(seconds=5),
loop=None,
**kwargs):
"""Initializes the Telegram client with the specified API ID and Hash.
@ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=connection_mode,
proxy=proxy,
timeout=timeout,
loop=loop,
**kwargs
)
@ -202,7 +204,7 @@ class TelegramClient(TelegramBareClient):
"""Gets "me" (the self user) which is currently authenticated,
or None if the request fails (hence, not authenticated)."""
try:
return await self(GetUsersRequest([InputUserSelf()]))[0]
return (await self(GetUsersRequest([InputUserSelf()])))[0]
except UnauthorizedError:
return None
@ -313,6 +315,7 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_reply_to(reply_to)
)
result = await self(request)
if isinstance(result, UpdateShortSentMessage):
return Message(
id=result.id,

View File

@ -11,6 +11,15 @@ class MessageContainer(TLObject):
self.content_related = False
self.messages = messages
def to_dict(self, recursive=True):
return {
'content_related': self.content_related,
'messages':
([] if self.messages is None else [
None if x is None else x.to_dict() for x in self.messages
]) if recursive else self.messages,
}
def __bytes__(self):
return struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
@ -25,3 +34,9 @@ class MessageContainer(TLObject):
inner_sequence = reader.read_int()
inner_length = reader.read_int()
yield inner_msg_id, inner_sequence, inner_length
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,4 +1,5 @@
import struct
import logging
from . import TLObject, GzipPacked
@ -11,7 +12,23 @@ class TLMessage(TLObject):
self.msg_id = session.get_new_msg_id()
self.seq_no = session.generate_sequence(request.content_related)
self.request = request
self.container_msg_id = None
logging.getLogger(__name__).debug(self)
def to_dict(self, recursive=True):
return {
'msg_id': self.msg_id,
'seq_no': self.seq_no,
'request': self.request,
'container_msg_id': self.container_msg_id,
}
def __bytes__(self):
body = GzipPacked.gzip_if_smaller(self.request)
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,12 +1,10 @@
from datetime import datetime
from threading import Event
class TLObject:
def __init__(self):
self.request_msg_id = 0 # Long
self.confirm_received = Event()
self.confirm_received = None
self.rpc_error = None
# These should be overrode

View File

@ -1,8 +1,8 @@
import logging
import pickle
import asyncio
from collections import deque
from datetime import datetime
from threading import RLock, Event, Thread
from .tl import types as tl
@ -13,177 +13,72 @@ class UpdateState:
"""
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke all the self.handlers.
"""
self._workers = workers
self._worker_threads = []
def __init__(self, loop=None):
self.handlers = []
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
self._latest_updates = deque(maxlen=10)
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__)
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self):
"""Returns True if a call to .poll() won't lock"""
return self._updates_available.is_set()
def poll(self, timeout=None):
"""Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
if not self._updates_available.wait(timeout=timeout):
return
with self._updates_lock:
if not self._updates_available.is_set():
return
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
if isinstance(update, Exception):
raise update # Some error was set through (surely StopIteration)
return update
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
self.stop_workers()
self._workers = n
if n is None:
self._updates.clear()
else:
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""Raises "StopIterationException" on the worker threads to stop them,
and also clears all of them off the list
"""
if self._workers:
with self._updates_lock:
# Insert at the beginning so the very next poll causes an error
# on all the worker threads
# TODO Should this reset the pts and such?
for _ in range(self._workers):
self._updates.appendleft(StopIteration())
self._updates_available.set()
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while True:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
# TODO Maybe people can add different handlers per update type
if update:
for handler in self.handlers:
handler(update)
except StopIteration:
break
except Exception as e:
# We don't want to crash a worker thread due to any reason
self._logger.debug(
'[ERROR] Unhandled exception on worker {}'.format(wid), e
)
def handle_update(self, update):
for handler in self.handlers:
asyncio.ensure_future(handler(update), loop=self._loop)
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if self._workers is None:
return # No processing needs to be done if nobody's working
if isinstance(update, tl.updates.State):
self._state = update
return # Nothing else to be done
with self._updates_lock:
if isinstance(update, tl.updates.State):
self._state = update
return # Nothing else to be done
pts = getattr(update, 'pts', self._state.pts)
if hasattr(update, 'pts') and pts <= self._state.pts:
return # We already handled this update
pts = getattr(update, 'pts', self._state.pts)
if hasattr(update, 'pts') and pts <= self._state.pts:
return # We already handled this update
self._state.pts = pts
self._state.pts = pts
# TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
# TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
self._latest_updates.append(data)
self._latest_updates.append(data)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
if isinstance(update, tl.UpdateShort):
self.handle_update(update.update)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
if isinstance(update, tl.UpdateShort):
self._updates.append(update.update)
self._updates_available.set()
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
for upd in update.updates:
self.handle_update(upd)
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
self._updates.extend(update.updates)
self._updates_available.set()
elif not isinstance(update, tl.UpdatesTooLong):
# TODO Handle "Updates too long"
self.handle_update(update)
elif not isinstance(update, tl.UpdatesTooLong):
# TODO Handle "Updates too long"
self._updates.append(update)
self._updates_available.set()
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
self._updates.append(update)
self._updates_available.set()
else:
self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__)
)
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
self.handle_update(update)
else:
self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__)
)