Update handlers works; it also seems stable

This commit is contained in:
Andrey Egorov 2017-10-22 15:06:36 +03:00
parent 917665852d
commit 780e0ceddf
9 changed files with 300 additions and 265 deletions

View File

@ -5,13 +5,12 @@ import socket
from datetime import timedelta
from io import BytesIO, BufferedWriter
loop = asyncio.get_event_loop()
class TcpClient:
def __init__(self, proxy=None, timeout=timedelta(seconds=5)):
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
self.proxy = proxy
self._socket = None
self._loop = loop if loop else asyncio.get_event_loop()
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
@ -31,7 +30,7 @@ class TcpClient:
else: # tuple, list, etc.
self._socket.set_proxy(*self.proxy)
self._socket.settimeout(self.timeout)
self._socket.setblocking(False)
async def connect(self, ip, port):
"""Connects to the specified IP and port number.
@ -42,20 +41,27 @@ class TcpClient:
else:
mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True:
try:
while not self._socket:
if not self._socket:
self._recreate_socket(mode)
await loop.sock_connect(self._socket, address)
await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(min(timeout, 15))
timeout *= 2
except OSError as e:
# There are some errors that we know how to handle, and
# the loop will allow us to retry
if e.errno == errno.EBADF:
if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.EINVAL]:
# Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration
self._socket = None
await asyncio.sleep(min(timeout, 15))
timeout *= 2
else:
raise
@ -81,13 +87,14 @@ class TcpClient:
raise ConnectionResetError()
try:
await loop.sock_sendall(self._socket, data)
except socket.timeout as e:
await asyncio.wait_for(self._loop.sock_sendall(self._socket, data),
timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except BrokenPipeError:
self._raise_connection_reset()
except OSError as e:
if e.errno == errno.EBADF:
if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]:
self._raise_connection_reset()
else:
raise
@ -104,11 +111,12 @@ class TcpClient:
bytes_left = size
while bytes_left != 0:
try:
partial = await loop.sock_recv(self._socket, bytes_left)
except socket.timeout as e:
partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left),
timeout=self.timeout, loop=self._loop)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e:
if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK:
if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]:
self._raise_connection_reset()
else:
raise

View File

@ -43,13 +43,13 @@ class Connection:
"""
def __init__(self, mode=ConnectionMode.TCP_FULL,
proxy=None, timeout=timedelta(seconds=5)):
proxy=None, timeout=timedelta(seconds=5), loop=None):
self._mode = mode
self._send_counter = 0
self._aes_encrypt, self._aes_decrypt = None, None
# TODO Rename "TcpClient" as some sort of generic socket?
self.conn = TcpClient(proxy=proxy, timeout=timeout)
self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop)
# Sending messages
if mode == ConnectionMode.TCP_FULL:
@ -206,7 +206,7 @@ class Connection:
return await self.conn.read(length)
async def _read_obfuscated(self, length):
return await self._aes_decrypt.encrypt(self.conn.read(length))
return self._aes_decrypt.encrypt(await self.conn.read(length))
# endregion

View File

@ -1,6 +1,8 @@
import gzip
import logging
import struct
import asyncio
from asyncio import Event
from .. import helpers as utils
from ..crypto import AES
@ -30,17 +32,15 @@ class MtProtoSender:
in parallel, so thread-safety (hence locking) isn't needed.
"""
def __init__(self, session, connection):
def __init__(self, session, connection, loop=None):
"""Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'.
"""
self.session = session
self.connection = connection
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__)
# Message IDs that need confirmation
self._need_confirmation = []
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
@ -54,12 +54,11 @@ class MtProtoSender:
def disconnect(self):
"""Disconnects from the server"""
self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending()
def clone(self):
"""Creates a copy of this MtProtoSender as a new connection"""
return MtProtoSender(self.session, self.connection.clone())
return MtProtoSender(self.session, self.connection.clone(), self._loop)
# region Send and receive
@ -67,21 +66,23 @@ class MtProtoSender:
"""Sends the specified MTProtoRequest, previously sending any message
which needed confirmation."""
# Prepare the event of every request
for r in requests:
if r.confirm_received is None:
r.confirm_received = Event(loop=self._loop)
else:
r.confirm_received.clear()
# Finally send our packed request(s)
messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages})
# Pack everything in the same container if we need to send AckRequests
if self._need_confirmation:
messages.append(
TLMessage(self.session, MsgsAck(self._need_confirmation))
)
self._need_confirmation.clear()
if len(messages) == 1:
message = messages[0]
else:
message = TLMessage(self.session, MessageContainer(messages))
for m in messages:
m.container_msg_id = message.msg_id
await self._send_message(message)
@ -115,6 +116,7 @@ class MtProtoSender:
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
# endregion
@ -174,7 +176,6 @@ class MtProtoSender:
"""
# TODO Check salt, session_id and sequence_number
self._need_confirmation.append(msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
@ -210,14 +211,14 @@ class MtProtoSender:
if code == MsgsAck.CONSTRUCTOR_ID: # may handle the request we wanted
ack = reader.tgread_object()
assert isinstance(ack, MsgsAck)
# Ignore every ack request *unless* when logging out, when it's
# Ignore every ack request *unless* when logging out,
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in ack.msg_ids:
r = self._pop_request_of_type(msg_id, LogOutRequest)
if r:
r.result = True # Telegram won't send this value
r.confirm_received()
r.confirm_received.set()
self._logger.debug('Message ack confirmed', r)
return True
@ -259,11 +260,29 @@ class MtProtoSender:
if message and isinstance(message.request, t):
return self._pending_receive.pop(msg_id).request
def _pop_requests_of_container(self, container_msg_id):
msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id]
requests = [msg.request for msg in msgs]
for msg in msgs:
self._pending_receive.pop(msg.msg_id, None)
return requests
def _clear_all_pending(self):
for r in self._pending_receive.values():
r.confirm_received.set()
r.request.confirm_received.set()
self._pending_receive.clear()
async def _resend_request(self, msg_id):
request = self._pop_request(msg_id)
if request:
self._logger.debug('requests is about to resend')
await self.send(request)
return
requests = self._pop_requests_of_container(msg_id)
if requests:
self._logger.debug('container of requests is about to resend')
await self.send(*requests)
async def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong')
pong = reader.tgread_object()
@ -303,10 +322,9 @@ class MtProtoSender:
self.session.salt = struct.unpack(
'<Q', struct.pack('<q', bad_salt.new_server_salt)
)[0]
self.session.save()
request = self._pop_request(bad_salt.bad_msg_id)
if request:
await self.send(request)
await self._resend_request(bad_salt.bad_msg_id)
return True
@ -322,15 +340,18 @@ class MtProtoSender:
self.session.update_time_offset(correct_msg_id=msg_id)
self._logger.debug('Read Bad Message error: ' + str(error))
self._logger.debug('Attempting to use the correct time offset.')
await self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64
await self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16
await self._resend_request(bad_msg.bad_msg_id)
return True
else:
raise error
@ -341,7 +362,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/VvpCC6
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader):
@ -350,7 +370,6 @@ class MtProtoSender:
# TODO For now, simply ack msg_new.answer_msg_id
# Relevant tdesktop source code: https://goo.gl/G7DPsR
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_new_session_created(self, msg_id, sequence, reader):
@ -378,9 +397,6 @@ class MtProtoSender:
reader.read_int(), reader.tgread_string()
)
# Acknowledge that we received the error
await self._send_acknowledge(request_id)
if request:
request.rpc_error = error
request.confirm_received.set()

View File

@ -1,10 +1,10 @@
import logging
import os
import warnings
import asyncio
from datetime import timedelta, datetime
from hashlib import md5
from io import BytesIO
from time import sleep
from asyncio import Lock
from . import helpers as utils
from .crypto import rsa, CdnDecrypter
@ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .tl import TLObject, Session
from .tl.all_tlobjects import LAYER
from .tl.functions import (
InitConnectionRequest, InvokeWithLayerRequest
InitConnectionRequest, InvokeWithLayerRequest, PingRequest
)
from .tl.functions.auth import (
ImportAuthorizationRequest, ExportAuthorizationRequest
@ -67,6 +67,7 @@ class TelegramBareClient:
connection_mode=ConnectionMode.TCP_FULL,
proxy=None,
timeout=timedelta(seconds=5),
loop=None,
**kwargs):
"""Refer to TelegramClient.__init__ for docs on this method"""
if not api_id or not api_hash:
@ -82,6 +83,8 @@ class TelegramBareClient:
'The given session must be a str or a Session instance.'
)
self._loop = loop if loop else asyncio.get_event_loop()
self.session = session
self.api_id = int(api_id)
self.api_hash = api_hash
@ -92,12 +95,18 @@ class TelegramBareClient:
# that calls .connect(). Every other thread will spawn a new
# temporary connection. The connection on this one is always
# kept open so Telegram can send us updates.
self._sender = MtProtoSender(self.session, Connection(
mode=connection_mode, proxy=proxy, timeout=timeout
))
self._sender = MtProtoSender(
self.session,
Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop),
self._loop
)
self._logger = logging.getLogger(__name__)
# Two coroutines may be calling reconnect() when the connection is lost,
# we only want one to actually perform the reconnection.
self._reconnect_lock = Lock(loop=self._loop)
# Cache "exported" sessions as 'dc_id: Session' not to recreate
# them all the time since generating a new key is a relatively
# expensive operation.
@ -105,7 +114,7 @@ class TelegramBareClient:
# This member will process updates if enabled.
# One may change self.updates.enabled at any later point.
self.updates = UpdateState(workers=None)
self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect
kwargs['app_version'] = kwargs.get('app_version', self.__version__)
@ -129,10 +138,11 @@ class TelegramBareClient:
# Uploaded files cache so subsequent calls are instant
self._upload_cache = {}
# Default PingRequest delay
self._last_ping = datetime.now()
self._ping_delay = timedelta(minutes=1)
self._recv_loop = None
self._ping_loop = None
# Default PingRequest delay
self._ping_delay = timedelta(minutes=1)
# endregion
@ -167,6 +177,7 @@ class TelegramBareClient:
self.session.auth_key, self.session.time_offset = \
await authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError:
self._user_connected = False
return False
self.session.layer = LAYER
@ -198,12 +209,12 @@ class TelegramBareClient:
# another data center and this would raise UserMigrateError)
# to also assert whether the user is logged in or not.
self._user_connected = True
if _sync_updates and not _cdn:
if _sync_updates and not _cdn and not self._authorized:
try:
await self.sync_updates()
self._set_connected_and_authorized()
except UnauthorizedError:
self._authorized = False
pass
return True
@ -211,7 +222,7 @@ class TelegramBareClient:
# This is fine, probably layer migration
self._logger.debug('Found invalid item, probably migrating', e)
self.disconnect()
return self.connect(
return await self.connect(
_exported_auth=_exported_auth,
_sync_updates=_sync_updates,
_cdn=_cdn
@ -261,7 +272,17 @@ class TelegramBareClient:
"""
if new_dc is None:
# Assume we are disconnected due to some error, so connect again
return await self.connect()
try:
await self._reconnect_lock.acquire()
# Another thread may have connected again, so check that first
if self.is_connected():
return True
return await self.connect()
except ConnectionResetError:
return False
finally:
self._reconnect_lock.release()
else:
self.disconnect()
self.session.auth_key = None # Force creating new auth_key
@ -337,7 +358,8 @@ class TelegramBareClient:
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
await client.connect(_exported_auth=export_auth, _sync_updates=False)
client._authorized = True # We exported the auth, so we got auth
@ -356,7 +378,8 @@ class TelegramBareClient:
client = TelegramBareClient(
session, self.api_id, self.api_hash,
proxy=self._sender.connection.conn.proxy,
timeout=self._sender.connection.get_timeout()
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
# This will make use of the new RSA keys for this specific CDN.
@ -381,55 +404,52 @@ class TelegramBareClient:
x.content_related for x in requests):
raise ValueError('You can only invoke requests, not types!')
# TODO Determine the sender to be used (main or a new connection)
sender = self._sender # .clone(), .connect()
# We should call receive from this thread if there's no background
# thread reading or if the server disconnected us and we're trying
# to reconnect. This is because the read thread may either be
# locked also trying to reconnect or we may be said thread already.
call_receive = self._recv_loop is None
try:
for _ in range(retries):
result = await self._invoke(sender, *requests)
if result is not None:
return result
for retry in range(retries):
result = await self._invoke(call_receive, retry, *requests)
if result is not None:
return result
raise ValueError('Number of retries reached 0.')
finally:
if sender != self._sender:
sender.disconnect() # Close temporary connections
return None
# Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__
async def _invoke(self, sender, *requests):
async def _invoke(self, call_receive, retry, *requests):
try:
# Ensure that we start with no previous errors (i.e. resending)
for x in requests:
x.confirm_received.clear()
x.rpc_error = None
await sender.send(*requests)
while not all(x.confirm_received.is_set() for x in requests):
await sender.receive(update_state=self.updates)
await self._sender.send(*requests)
except TimeoutError:
pass # We will just retry
if not call_receive:
await asyncio.wait(
list(map(lambda x: x.confirm_received.wait(), requests)),
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
else:
while not all(x.confirm_received.is_set() for x in requests):
await self._sender.receive(update_state=self.updates)
except ConnectionResetError:
if not self._user_connected:
# Only attempt reconnecting if we're authorized
if not self._user_connected or self._reconnect_lock.locked():
# Only attempt reconnecting if the user called connect and not
# reconnecting already.
raise
self._logger.debug('Server disconnected us. Reconnecting and '
'resending request...')
if sender != self._sender:
# TODO Try reconnecting forever too?
await sender.connect()
else:
while self._user_connected and not await self._reconnect():
sleep(0.1) # Retry forever until we can send the request
finally:
if sender != self._sender:
sender.disconnect()
'resending request... (%d)' % retry)
await self._reconnect()
if not self._sender.is_connected():
await asyncio.sleep(retry + 1, loop=self._loop)
return None
try:
raise next(x.rpc_error for x in requests if x.rpc_error)
@ -452,7 +472,7 @@ class TelegramBareClient:
)
await self._reconnect(new_dc=e.new_dc)
return await self._invoke(sender, *requests)
return None
except ServerError as e:
# Telegram is having some issues, just retry
@ -467,7 +487,8 @@ class TelegramBareClient:
self._logger.debug(
'Sleep of %d seconds below threshold, sleeping' % e.seconds
)
sleep(e.seconds)
await asyncio.sleep(e.seconds, loop=self._loop)
return None
# Some really basic functionality
@ -670,16 +691,13 @@ class TelegramBareClient:
"""
self.updates.process(await self(GetStateRequest()))
def add_update_handler(self, handler):
async def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates"""
if not self.updates.get_workers:
warnings.warn("There are no update workers running, so adding an update handler will have no effect.")
sync = not self.updates.handlers
self.updates.handlers.append(handler)
if sync:
self.sync_updates()
await self.sync_updates()
def remove_update_handler(self, handler):
self.updates.handlers.remove(handler)
@ -693,6 +711,63 @@ class TelegramBareClient:
def _set_connected_and_authorized(self):
self._authorized = True
# TODO self.updates.setup_workers()
if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
self._ping_loop = None
async def _recv_loop_impl(self):
need_reconnect = False
timeout = 1
while self._user_connected:
try:
if need_reconnect:
need_reconnect = False
while self._user_connected and not await self._reconnect():
await asyncio.sleep(0.1, loop=self._loop) # Retry forever, this is instant messaging
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
pass
except ConnectionError as error:
self._logger.debug(error)
need_reconnect = True
await asyncio.sleep(min(timeout, 15), loop=self._loop)
timeout *= 2
except Exception as error:
# Unknown exception, pass it to the main thread
self._logger.debug(
'[ERROR] Unknown error on the read loop, please report',
error
)
try:
import socks
if isinstance(error, (
socks.GeneralProxyError, socks.ProxyConnectionError
)):
# This is a known error, and it's not related to
# Telegram but rather to the proxy. Disconnect and
# hand it over to the main thread.
self._background_error = error
self.disconnect()
break
except ImportError:
"Not using PySocks, so it can't be a socket error"
# If something strange happens we don't want to enter an
# infinite loop where all we do is raise an exception, so
# add a little sleep to avoid the CPU usage going mad.
await asyncio.sleep(0.1, loop=self._loop)
break
timeout = 1
self._recv_loop = None
# endregion

View File

@ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=ConnectionMode.TCP_FULL,
proxy=None,
timeout=timedelta(seconds=5),
loop=None,
**kwargs):
"""Initializes the Telegram client with the specified API ID and Hash.
@ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient):
connection_mode=connection_mode,
proxy=proxy,
timeout=timeout,
loop=loop,
**kwargs
)
@ -104,8 +106,9 @@ class TelegramClient(TelegramBareClient):
"""Sends a code request to the specified phone number"""
phone = EntityDatabase.parse_phone(phone) or self._phone
result = await self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone
self._phone_code_hash = result.phone_code_hash
if result:
self._phone = phone
self._phone_code_hash = result.phone_code_hash
return result
async def sign_in(self, phone=None, code=None,
@ -169,8 +172,10 @@ class TelegramClient(TelegramBareClient):
'and a password only if an RPCError was raised before.'
)
self._set_connected_and_authorized()
return result.user
if result:
self._set_connected_and_authorized()
return result.user
return result
async def sign_up(self, code, first_name, last_name=''):
"""Signs up to Telegram. Make sure you sent a code request first!"""
@ -182,8 +187,10 @@ class TelegramClient(TelegramBareClient):
last_name=last_name
))
self._set_connected_and_authorized()
return result.user
if result:
self._set_connected_and_authorized()
return result.user
return result
async def log_out(self):
"""Logs out and deletes the current session.
@ -239,7 +246,7 @@ class TelegramClient(TelegramBareClient):
offset_peer=offset_peer,
limit=need if need < float('inf') else 0
))
if not r.dialogs:
if not r or not r.dialogs:
break
for d in r.dialogs:
@ -288,10 +295,12 @@ class TelegramClient(TelegramBareClient):
:return List[telethon.tl.custom.Draft]: A list of open drafts
"""
response = await self(GetAllDraftsRequest())
self.session.process_entities(response)
self.session.generate_sequence(response.seq)
drafts = [Draft._from_update(self, u) for u in response.updates]
return drafts
if response:
self.session.process_entities(response)
self.session.generate_sequence(response.seq)
drafts = [Draft._from_update(self, u) for u in response.updates]
return drafts
return response
async def send_message(self,
entity,
@ -313,6 +322,9 @@ class TelegramClient(TelegramBareClient):
reply_to_msg_id=self._get_reply_to(reply_to)
)
result = await self(request)
if not result:
return result
if isinstance(result, UpdateShortSentMessage):
return Message(
id=result.id,
@ -407,6 +419,8 @@ class TelegramClient(TelegramBareClient):
min_id=min_id,
add_offset=add_offset
))
if not result:
return result
# The result may be a messages slice (not all messages were retrieved)
# or simply a messages TLObject. In the later case, no "count"

View File

@ -11,6 +11,12 @@ class MessageContainer(TLObject):
self.content_related = False
self.messages = messages
def to_dict(self, recursive=True):
return {
'content_related': self.content_related,
'messages': self.messages,
}
def to_bytes(self):
return struct.pack(
'<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages)
@ -25,3 +31,9 @@ class MessageContainer(TLObject):
inner_sequence = reader.read_int()
inner_length = reader.read_int()
yield inner_msg_id, inner_sequence, inner_length
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,4 +1,5 @@
import struct
import logging
from . import TLObject, GzipPacked
@ -11,7 +12,23 @@ class TLMessage(TLObject):
self.msg_id = session.get_new_msg_id()
self.seq_no = session.generate_sequence(request.content_related)
self.request = request
self.container_msg_id = None
logging.getLogger(__name__).debug(self)
def to_dict(self, recursive=True):
return {
'msg_id': self.msg_id,
'seq_no': self.seq_no,
'request': self.request,
'container_msg_id': self.container_msg_id,
}
def to_bytes(self):
body = GzipPacked.gzip_if_smaller(self.request)
return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body
def __str__(self):
return TLObject.pretty_format(self)
def stringify(self):
return TLObject.pretty_format(self, indent=0)

View File

@ -1,12 +1,10 @@
from datetime import datetime
from threading import Event
class TLObject:
def __init__(self):
self.request_msg_id = 0 # Long
self.confirm_received = Event()
self.confirm_received = None
self.rpc_error = None
# These should be overrode

View File

@ -1,8 +1,8 @@
import logging
import pickle
import asyncio
from collections import deque
from datetime import datetime
from threading import RLock, Event, Thread
from .tl import types as tl
@ -13,177 +13,72 @@ class UpdateState:
"""
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, workers=None):
"""
:param workers: This integer parameter has three possible cases:
workers is None: Updates will *not* be stored on self.
workers = 0: Another thread is responsible for calling self.poll()
workers > 0: 'workers' background threads will be spawned, any
any of them will invoke all the self.handlers.
"""
self._workers = workers
self._worker_threads = []
def __init__(self, loop=None):
self.handlers = []
self._updates_lock = RLock()
self._updates_available = Event()
self._updates = deque()
self._latest_updates = deque(maxlen=10)
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__)
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def can_poll(self):
"""Returns True if a call to .poll() won't lock"""
return self._updates_available.is_set()
def poll(self, timeout=None):
"""Polls an update or blocks until an update object is available.
If 'timeout is not None', it should be a floating point value,
and the method will 'return None' if waiting times out.
"""
if not self._updates_available.wait(timeout=timeout):
return
with self._updates_lock:
if not self._updates_available.is_set():
return
update = self._updates.popleft()
if not self._updates:
self._updates_available.clear()
if isinstance(update, Exception):
raise update # Some error was set through (surely StopIteration)
return update
def get_workers(self):
return self._workers
def set_workers(self, n):
"""Changes the number of workers running.
If 'n is None', clears all pending updates from memory.
"""
self.stop_workers()
self._workers = n
if n is None:
self._updates.clear()
else:
self.setup_workers()
workers = property(fget=get_workers, fset=set_workers)
def stop_workers(self):
"""Raises "StopIterationException" on the worker threads to stop them,
and also clears all of them off the list
"""
if self._workers:
with self._updates_lock:
# Insert at the beginning so the very next poll causes an error
# on all the worker threads
# TODO Should this reset the pts and such?
for _ in range(self._workers):
self._updates.appendleft(StopIteration())
self._updates_available.set()
for t in self._worker_threads:
t.join()
self._worker_threads.clear()
def setup_workers(self):
if self._worker_threads or not self._workers:
# There already are workers, or workers is None or 0. Do nothing.
return
for i in range(self._workers):
thread = Thread(
target=UpdateState._worker_loop,
name='UpdateWorker{}'.format(i),
daemon=True,
args=(self, i)
)
self._worker_threads.append(thread)
thread.start()
def _worker_loop(self, wid):
while True:
try:
update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT)
# TODO Maybe people can add different handlers per update type
if update:
for handler in self.handlers:
handler(update)
except StopIteration:
break
except Exception as e:
# We don't want to crash a worker thread due to any reason
self._logger.debug(
'[ERROR] Unhandled exception on worker {}'.format(wid), e
)
def handle_update(self, update):
for handler in self.handlers:
asyncio.ensure_future(handler(update), loop=self._loop)
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if self._workers is None:
return # No processing needs to be done if nobody's working
if isinstance(update, tl.updates.State):
self._state = update
return # Nothing else to be done
with self._updates_lock:
if isinstance(update, tl.updates.State):
self._state = update
return # Nothing else to be done
pts = getattr(update, 'pts', self._state.pts)
if hasattr(update, 'pts') and pts <= self._state.pts:
return # We already handled this update
pts = getattr(update, 'pts', self._state.pts)
if hasattr(update, 'pts') and pts <= self._state.pts:
return # We already handled this update
self._state.pts = pts
self._state.pts = pts
# TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
# TODO There must be a better way to handle updates rather than
# keeping a queue with the latest updates only, and handling
# the 'pts' correctly should be enough. However some updates
# like UpdateUserStatus (even inside UpdateShort) will be called
# repeatedly very often if invoking anything inside an update
# handler. TODO Figure out why.
"""
client = TelegramClient('anon', api_id, api_hash, update_workers=1)
client.connect()
def handle(u):
client.get_me()
client.add_update_handler(handle)
input('Enter to exit.')
"""
data = pickle.dumps(update.to_dict())
if data in self._latest_updates:
return # Duplicated too
self._latest_updates.append(data)
self._latest_updates.append(data)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
if isinstance(update, tl.UpdateShort):
self.handle_update(update.update)
if type(update).SUBCLASS_OF_ID == 0x8af52aac: # crc32(b'Updates')
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
if isinstance(update, tl.UpdateShort):
self._updates.append(update.update)
self._updates_available.set()
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
for upd in update.updates:
self.handle_update(upd)
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
self._updates.extend(update.updates)
self._updates_available.set()
elif not isinstance(update, tl.UpdatesTooLong):
# TODO Handle "Updates too long"
self.handle_update(update)
elif not isinstance(update, tl.UpdatesTooLong):
# TODO Handle "Updates too long"
self._updates.append(update)
self._updates_available.set()
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
self._updates.append(update)
self._updates_available.set()
else:
self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__)
)
elif type(update).SUBCLASS_OF_ID == 0x9f89304e: # crc32(b'Update')
self.handle_update(update)
else:
self._logger.debug('Ignoring "update" of type {}'.format(
type(update).__name__)
)