Merge branch 'asyncio' into asyncio-upstream

* asyncio:
  No need to save (the salt is no longer kept in the DB)
  Handle a very rare exception on reconnect
  Move updates_handler out of MtProtoSender so GC works properly; fix unauth_handler log format
  Fix memory leaks
  Pretty formatting of TLObjects
  More accurate clearing of pending requests
  Another attempt to prevent duplicates
  Handle updates, plus other refactoring
  Add a SocketClosed exception
  Refactor TcpClient
  Log socket OSErrors
  Catch network errors more aggressively
  Catch "No route to host" and other errnos to trigger a reconnect

# Conflicts (resolved):
#	telethon/extensions/tcp_client.py
#	telethon/network/mtproto_sender.py
#	telethon/telegram_bare_client.py
#	telethon/tl/session.py
Andrey Egorov 2018-06-14 14:34:08 +03:00
commit 43a0226b33
9 changed files with 550 additions and 484 deletions
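
The bullets above revolve around treating more socket-level errnos as recoverable connection resets, so the client reconnects instead of crashing. A minimal sketch of that classification outside Telethon's own classes (send_or_reset is a hypothetical helper; the errno set roughly mirrors the one added to the TCP client below):

import errno
import socket

# Errnos treated as "connection reset": the caller should reconnect, not crash.
CONN_RESET_ERRNOS = {
    errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
    errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
    errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
    errno.ENETDOWN, errno.ENETRESET,
    errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN,
}

def send_or_reset(sock: socket.socket, data: bytes) -> None:
    """Send data, converting known transport errnos into ConnectionResetError."""
    try:
        sock.sendall(data)
    except OSError as e:
        if e.errno in CONN_RESET_ERRNOS:
            # The caller is expected to catch ConnectionResetError and reconnect.
            raise ConnectionResetError(str(e)) from e
        raise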

View File

@ -6,32 +6,33 @@ import asyncio
import errno
import logging
import socket
import time
from datetime import timedelta
from io import BytesIO, BufferedWriter
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
}
# caught: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
try:
import socks
except ImportError:
socks = None
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN
}
__log__ = logging.getLogger(__name__)
class TcpClient:
"""A simple TCP client to ease the work with sockets and proxies."""
class SocketClosed(ConnectionError):
pass
def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
"""
Initializes the TCP client.
@ -42,6 +43,8 @@ class TcpClient:
self.proxy = proxy
self._socket = None
self._loop = loop if loop else asyncio.get_event_loop()
self._closed = asyncio.Event(loop=self._loop)
self._closed.set()
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
@ -76,41 +79,28 @@ class TcpClient:
else:
mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True:
try:
if not self._socket:
self._recreate_socket(mode)
try:
if not self._socket:
self._recreate_socket(mode)
await self._loop.sock_connect(self._socket, address)
break # Successful connection, stop retrying to connect
except ConnectionError:
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
except OSError as e:
__log__.info('OSError "%s" raised while connecting', e)
# Stop retrying to connect if proxy connection error occurred
if socks and isinstance(e, socks.ProxyConnectionError):
raise
# There are some errors that we know how to handle, and
# the loop will allow us to retry
if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL,
errno.ECONNREFUSED, # Windows-specific follow
getattr(errno, 'WSAEACCES', None)):
# Bad file descriptor, i.e. socket was closed, set it
# to none to recreate it on the next iteration
self._socket = None
await asyncio.sleep(timeout)
timeout *= 2
if timeout > MAX_TIMEOUT:
raise
else:
raise
await asyncio.wait_for(
self._loop.sock_connect(self._socket, address),
timeout=self.timeout,
loop=self._loop
)
self._closed.clear()
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset(e)
else:
raise
def _get_connected(self):
"""Determines whether the client is connected or not."""
return self._socket is not None and self._socket.fileno() >= 0
return not self._closed.is_set()
connected = property(fget=_get_connected)
@ -118,12 +108,29 @@ class TcpClient:
"""Closes the connection."""
try:
if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR)
if self.connected:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally:
self._socket = None
self._closed.set()
async def _wait_close(self, coro):
done, running = await asyncio.wait(
[coro, self._closed.wait()],
timeout=self.timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
for r in running:
r.cancel()
if not self.connected:
raise self.SocketClosed()
if not done:
raise TimeoutError()
return done.pop().result()
async def write(self, data):
"""
@ -131,21 +138,12 @@ class TcpClient:
:param data: the data to send.
"""
if self._socket is None:
self._raise_connection_reset(None)
if not self.connected:
raise ConnectionResetError('No connection')
try:
await asyncio.wait_for(
self.sock_sendall(data),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
__log__.debug('socket.timeout "%s" while writing data', e)
raise TimeoutError() from e
except ConnectionError as e:
__log__.info('ConnectionError "%s" while writing data', e)
self._raise_connection_reset(e)
await self._wait_close(self.sock_sendall(data))
except self.SocketClosed:
raise ConnectionResetError('Socket has closed')
except OSError as e:
__log__.info('OSError "%s" while writing data', e)
if e.errno in CONN_RESET_ERRNOS:
@ -160,21 +158,15 @@ class TcpClient:
:param size: the size of the block to be read.
:return: the read data with len(data) == size.
"""
if self._socket is None:
self._raise_connection_reset(None)
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size
partial = b''
while bytes_left != 0:
if not self.connected:
raise ConnectionResetError('No connection')
try:
if self._socket is None:
self._raise_connection_reset()
partial = await asyncio.wait_for(
self.sock_recv(bytes_left),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
partial = await self._wait_close(self.sock_recv(bytes_left))
except TimeoutError as e:
# These are somewhat common if the server has nothing
# to send to us, so use a lower logging priority.
if bytes_left < size:
@ -187,10 +179,9 @@ class TcpClient:
'socket.timeout "%s" while reading data', e
)
raise TimeoutError() from e
except ConnectionError as e:
__log__.info('ConnectionError "%s" while reading data', e)
self._raise_connection_reset(e)
raise
except self.SocketClosed:
raise ConnectionResetError('Socket has closed while reading data')
except OSError as e:
if e.errno != errno.EBADF:
# Ignore bad file descriptor while closing
@ -202,7 +193,7 @@ class TcpClient:
raise
if len(partial) == 0:
self._raise_connection_reset(None)
self._raise_connection_reset('No data on read')
buffer.write(partial)
bytes_left -= len(partial)
@ -211,10 +202,12 @@ class TcpClient:
buffer.flush()
return buffer.raw.getvalue()
def _raise_connection_reset(self, original):
"""Disconnects the client and raises ConnectionResetError."""
def _raise_connection_reset(self, error):
description = error if isinstance(error, str) else str(error)
if isinstance(error, str):
error = Exception(error)
self.close() # Connection reset -> flag as socket closed
raise ConnectionResetError('The server has closed the connection.') from original
raise ConnectionResetError(description) from error
# due to new https://github.com/python/cpython/pull/4386
def sock_recv(self, n):
@ -225,7 +218,7 @@ class TcpClient:
def _sock_recv(self, fut, registered_fd, n):
if registered_fd is not None:
self._loop.remove_reader(registered_fd)
if fut.cancelled():
if fut.cancelled() or self._socket is None:
return
try:
@ -249,7 +242,7 @@ class TcpClient:
def _sock_sendall(self, fut, registered_fd, data):
if registered_fd:
self._loop.remove_writer(registered_fd)
if fut.cancelled():
if fut.cancelled() or self._socket is None:
return
try:
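
The _wait_close helper above races a socket coroutine against the _closed event, so close() immediately unblocks pending reads and writes. The same pattern in a standalone sketch (ClosableOp and run are illustrative names, not Telethon's API):

import asyncio

class ClosableOp:
    """Race an awaitable against a 'closed' event with a timeout."""

    def __init__(self):
        self._closed = asyncio.Event()

    def close(self):
        # Wakes up any operation currently racing against the closed event.
        self._closed.set()

    async def run(self, coro, timeout=5.0):
        op = asyncio.ensure_future(coro)
        closed = asyncio.ensure_future(self._closed.wait())
        done, pending = await asyncio.wait(
            [op, closed], timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if op in done:
            return op.result()
        if closed in done:
            raise ConnectionResetError('closed while waiting')
        raise TimeoutError()

Used around a socket read, a close() from another task turns a blocked read into an exception right away instead of waiting for the timeout, which is what write() and read() above rely on.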

View File

@ -14,11 +14,11 @@ from ..errors import (
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.all_tlobjects import tlobjects
from ..tl.functions import InvokeAfterMsgRequest
from ..tl.functions.auth import LogOutRequest
from ..tl.types import (
MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
MsgNewDetailedInfo, MsgDetailedInfo, MsgsStateReq, MsgResendReq,
MsgsAllInfo, MsgsStateInfo, RpcError
)
__log__ = logging.getLogger(__name__)
@ -56,7 +56,8 @@ class MtProtoSender:
# receiving other request from the main thread (e.g. an update arrives
# and we need to process it) we must ensure that only one is calling
# receive at a given moment, since the receive step is fragile.
self._recv_lock = asyncio.Lock()
self._read_lock = asyncio.Lock(loop=self._loop)
self._write_lock = asyncio.Lock(loop=self._loop)
# Requests (as msg_id: Message) sent waiting to be received
self._pending_receive = {}
@ -73,11 +74,12 @@ class MtProtoSender:
"""
return self.connection.is_connected()
def disconnect(self):
def disconnect(self, clear_pendings=True):
"""Disconnects from the server."""
__log__.info('Disconnecting MtProtoSender...')
self.connection.close()
self._clear_all_pending()
if clear_pendings:
self._clear_all_pending()
# region Send and receive
@ -90,6 +92,7 @@ class MtProtoSender:
:param ordered: whether the requests should be invoked in the
order in which they appear or they can be executed
in arbitrary order in the server.
:return: a list of msg_ids corresponding to the sent requests.
"""
if not utils.is_list_like(requests):
requests = (requests,)
@ -111,6 +114,7 @@ class MtProtoSender:
messages = [TLMessage(self.session, r) for r in requests]
self._pending_receive.update({m.msg_id: m for m in messages})
msg_ids = [m.msg_id for m in messages]
__log__.debug('Sending requests with IDs: %s', ', '.join(
'{}: {}'.format(m.request.__class__.__name__, m.msg_id)
@ -128,12 +132,18 @@ class MtProtoSender:
m.container_msg_id = message.msg_id
await self._send_message(message)
return msg_ids
def forget_pendings(self, msg_ids):
for msg_id in msg_ids:
if msg_id in self._pending_receive:
del self._pending_receive[msg_id]
async def _send_acknowledge(self, msg_id):
"""Sends a message acknowledge for the given msg_id."""
await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
async def receive(self, update_state):
async def receive(self, updates_handler):
"""
Receives a single message from the connected endpoint.
@ -144,21 +154,13 @@ class MtProtoSender:
Any unhandled object (likely updates) will be passed to
update_state.process(TLObject).
:param update_state:
the UpdateState that will process all the received
:param updates_handler:
the handler that will process all the received
Update and Updates objects.
"""
if self._recv_lock.locked():
with await self._recv_lock:
# Don't busy wait, acquire it but return because there's
# already a receive running and we don't want another one.
# It would lock until Telegram sent another update even if
# the current receive already received the expected response.
return
await self._read_lock.acquire()
try:
with await self._recv_lock:
body = await self.connection.recv()
body = await self.connection.recv()
except (BufferError, InvalidChecksumError):
# TODO BufferError, we should spot the cause...
# "No more bytes left"; something wrong happened, clear
@ -172,11 +174,12 @@ class MtProtoSender:
len(self._pending_receive))
self._clear_all_pending()
return
finally:
self._read_lock.release()
message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader:
await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
await self._send_acknowledge(remote_msg_id)
await self._process_msg(remote_msg_id, remote_seq, reader, updates_handler)
# endregion
@ -188,7 +191,11 @@ class MtProtoSender:
:param message: the TLMessage to be sent.
"""
await self.connection.send(helpers.pack_message(self.session, message))
await self._write_lock.acquire()
try:
await self.connection.send(helpers.pack_message(self.session, message))
finally:
self._write_lock.release()
def _decode_msg(self, body):
"""
@ -206,14 +213,14 @@ class MtProtoSender:
with BinaryReader(body) as reader:
return helpers.unpack_message(self.session, reader)
async def _process_msg(self, msg_id, sequence, reader, state):
async def _process_msg(self, msg_id, sequence, reader, updates_handler):
"""
Processes the message read from the network inside reader.
:param msg_id: the ID of the message.
:param sequence: the sequence of the message.
:param reader: the BinaryReader that contains the message.
:param state: the current UpdateState.
:param updates_handler: the handler to process Update and Updates objects.
:return: true if the message was handled correctly, false otherwise.
"""
# TODO Check salt, session_id and sequence_number
@ -224,15 +231,16 @@ class MtProtoSender:
# These are a bit of special case, not yet generated by the code gen
if code == 0xf35c6d01: # rpc_result, (response of an RPC call)
__log__.debug('Processing Remote Procedure Call result')
await self._send_acknowledge(msg_id)
return await self._handle_rpc_result(msg_id, sequence, reader)
if code == MessageContainer.CONSTRUCTOR_ID:
__log__.debug('Processing container result')
return await self._handle_container(msg_id, sequence, reader, state)
return await self._handle_container(msg_id, sequence, reader, updates_handler)
if code == GzipPacked.CONSTRUCTOR_ID:
__log__.debug('Processing gzipped result')
return await self._handle_gzip_packed(msg_id, sequence, reader, state)
return await self._handle_gzip_packed(msg_id, sequence, reader, updates_handler)
if code not in tlobjects:
__log__.warning(
@ -250,6 +258,14 @@ class MtProtoSender:
if isinstance(obj, BadServerSalt):
return await self._handle_bad_server_salt(msg_id, sequence, obj)
if isinstance(obj, (MsgsStateReq, MsgResendReq)):
# just answer that we don't know anything
return await self._handle_msgs_state_forgotten(msg_id, sequence, obj)
if isinstance(obj, MsgsAllInfo):
# not interesting now
return True
if isinstance(obj, BadMsgNotification):
return await self._handle_bad_msg_notification(msg_id, sequence, obj)
@ -259,11 +275,8 @@ class MtProtoSender:
if isinstance(obj, MsgNewDetailedInfo):
return await self._handle_msg_new_detailed_info(msg_id, sequence, obj)
if isinstance(obj, NewSessionCreated):
return await self._handle_new_session_created(msg_id, sequence, obj)
if isinstance(obj, MsgsAck): # may handle the request we wanted
# Ignore every ack request *unless* when logging out, when it's
# Ignore every ack request *unless* when logging out,
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in obj.msg_ids:
@ -284,8 +297,9 @@ class MtProtoSender:
# If the object isn't any of the above, then it should be an Update.
self.session.process_entities(obj)
if state:
state.process(obj)
await self._send_acknowledge(msg_id)
if updates_handler:
updates_handler(obj)
return True
@ -372,22 +386,24 @@ class MtProtoSender:
return True
async def _handle_container(self, msg_id, sequence, reader, state):
async def _handle_container(self, msg_id, sequence, reader, updates_handler):
"""
Handles a MessageContainer response.
:param msg_id: the ID of the message.
:param sequence: the sequence of the message.
:param reader: the reader containing the MessageContainer.
:param updates_handler: the handler that will process Update and Updates objects.
:return: true, as it always succeeds.
"""
__log__.debug('Handling container')
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
begin_position = reader.tell_position()
# Note that this code is IMPORTANT for skipping RPC results of
# lost requests (i.e., ones from the previous connection session)
try:
if not await self._process_msg(inner_msg_id, sequence, reader, state):
if not await self._process_msg(inner_msg_id, sequence, reader, updates_handler):
reader.set_position(begin_position + inner_len)
except:
# If any error is raised, something went wrong; skip the packet
@ -406,7 +422,6 @@ class MtProtoSender:
:return: true, as it always succeeds.
"""
self.session.salt = bad_salt.new_server_salt
self.session.save()
# "the bad_server_salt response is received with the
# correct salt, and the message is to be re-sent with it"
@ -414,6 +429,10 @@ class MtProtoSender:
return True
async def _handle_msgs_state_forgotten(self, msg_id, sequence, req):
await self._send_message(TLMessage(self.session, MsgsStateInfo(msg_id, chr(1) * len(req.msg_ids))))
return True
async def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
"""
Handles a BadMessageError response.
@ -476,19 +495,6 @@ class MtProtoSender:
await self._send_acknowledge(msg_new.answer_msg_id)
return True
async def _handle_new_session_created(self, msg_id, sequence, new_session):
"""
Handles a NewSessionCreated response.
:param msg_id: the ID of the message.
:param sequence: the sequence of the message.
:param reader: the reader containing the NewSessionCreated.
:return: true, as it always succeeds.
"""
self.session.salt = new_session.server_salt
# TODO https://goo.gl/LMyN7A
return True
async def _handle_rpc_result(self, msg_id, sequence, reader):
"""
Handles a RPCResult response.
@ -507,7 +513,7 @@ class MtProtoSender:
__log__.debug('Received response for request with ID %d', request_id)
request = self._pop_request(request_id)
if inner_code == 0x2144ca19: # RPC Error
if inner_code == RpcError.CONSTRUCTOR_ID: # RPC Error
reader.seek(4)
if self.session.report_errors and request:
error = rpc_message_to_error(
@ -530,6 +536,7 @@ class MtProtoSender:
return True # All contents were read okay
elif request:
__log__.debug('Reading request response')
if inner_code == GzipPacked.CONSTRUCTOR_ID:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
request.on_response(compressed_reader)
@ -566,16 +573,18 @@ class MtProtoSender:
)
return False
async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
async def _handle_gzip_packed(self, msg_id, sequence, reader, updates_handler):
"""
Handles a GzipPacked response.
:param msg_id: the ID of the message.
:param sequence: the sequence of the message.
:param reader: the reader containing the GzipPacked.
:param updates_handler: the handler to process Update and Updates objects.
:return: the result of processing the packed message.
"""
__log__.debug('Handling gzip packed data')
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
return await self._process_msg(msg_id, sequence, compressed_reader, state)
return await self._process_msg(msg_id, sequence, compressed_reader, updates_handler)
# endregion
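
With this refactor send() returns the msg_ids of the pending requests and forget_pendings() lets the caller drop them when the answers can no longer arrive, which is what "More accurate clearing of pending requests" refers to. A sketch of how a caller might use that contract (invoke_once is hypothetical; sender and the request objects are assumed to expose the interfaces shown in this diff):

async def invoke_once(sender, requests, ordered=False):
    msg_ids = await sender.send(requests, ordered)
    try:
        # Keep receiving until every request has been confirmed by the server.
        while not all(r.confirm_received.is_set() for r in requests):
            await sender.receive(updates_handler=None)
        return [r.result for r in requests]
    except ConnectionResetError:
        # The answers will never arrive on this connection; forget them so the
        # pending-receive map does not grow forever (see "Fix memory leaks").
        sender.forget_pendings(msg_ids)
        raise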

View File

@ -67,6 +67,22 @@ class Session(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def user_id(self):
"""
Returns the ``user_id`` to which this session is related.
"""
raise NotImplementedError
@user_id.setter
@abstractmethod
def user_id(self, value):
"""
Sets the ``user_id`` to which this session is related.
"""
raise NotImplementedError
@abstractmethod
def get_update_state(self, entity_id):
"""
@ -94,7 +110,7 @@ class Session(ABC):
"""
@abstractmethod
def save(self):
async def save(self):
"""
Called whenever important properties change. It should
persist the relevant session information to disk.
@ -102,7 +118,7 @@ class Session(ABC):
raise NotImplementedError
@abstractmethod
def delete(self):
async def delete(self):
"""
Called upon client.log_out(). Should delete the stored
information from disk since it's not valid anymore.
@ -125,7 +141,7 @@ class Session(ABC):
raise NotImplementedError
@abstractmethod
def get_input_entity(self, key):
async def get_input_entity(self, key):
"""
Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``).
The library uses this method whenever an ``InputPeer`` is needed
@ -135,7 +151,7 @@ class Session(ABC):
raise NotImplementedError
@abstractmethod
def cache_file(self, md5_digest, file_size, instance):
async def cache_file(self, md5_digest, file_size, instance):
"""
Caches the given file information persistently, so that it
doesn't need to be re-uploaded in case the file is used again.
@ -146,7 +162,7 @@ class Session(ABC):
raise NotImplementedError
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
async def get_file(self, md5_digest, file_size, cls):
"""
Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size``
match an existing saved record. The class will either be an

View File

@ -32,6 +32,7 @@ class MemorySession(Session):
self._server_address = None
self._port = None
self._auth_key = None
self._user_id = None
self._files = {}
self._entities = set()
@ -58,6 +59,14 @@ class MemorySession(Session):
def auth_key(self, value):
self._auth_key = value
@property
def user_id(self):
return self._user_id
@user_id.setter
def user_id(self, value):
self._user_id = value
def get_update_state(self, entity_id):
return self._update_states.get(entity_id, None)
@ -67,10 +76,10 @@ class MemorySession(Session):
def close(self):
pass
def save(self):
async def save(self):
pass
def delete(self):
async def delete(self):
pass
def _entity_values_to_row(self, id, hash, username, phone, name):
@ -170,7 +179,7 @@ class MemorySession(Session):
except StopIteration:
pass
def get_input_entity(self, key):
async def get_input_entity(self, key):
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
@ -215,14 +224,14 @@ class MemorySession(Session):
else:
raise ValueError('Could not find input entity with key ', key)
def cache_file(self, md5_digest, file_size, instance):
async def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(instance))
value = (instance.id, instance.access_hash)
self._files[key] = value
def get_file(self, md5_digest, file_size, cls):
async def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls))
try:
return cls(self._files[key])

View File

@ -213,7 +213,7 @@ class SQLiteSession(MemorySession):
))
c.close()
def get_update_state(self, entity_id):
async def get_update_state(self, entity_id):
c = self._cursor()
row = c.execute('select pts, qts, date, seq from update_state '
'where id = ?', (entity_id,)).fetchone()
@ -223,7 +223,7 @@ class SQLiteSession(MemorySession):
date = datetime.datetime.utcfromtimestamp(date)
return types.updates.State(pts, qts, date, seq, unread_count=0)
def set_update_state(self, entity_id, state):
async def set_update_state(self, entity_id, state):
c = self._cursor()
c.execute('insert or replace into update_state values (?,?,?,?,?)',
(entity_id, state.pts, state.qts,
@ -231,7 +231,7 @@ class SQLiteSession(MemorySession):
c.close()
self.save()
def save(self):
async def save(self):
"""Saves the current session object as session_user_id.session"""
self._conn.commit()
@ -248,7 +248,7 @@ class SQLiteSession(MemorySession):
self._conn.close()
self._conn = None
def delete(self):
async def delete(self):
"""Deletes the current session file"""
if self.filename == ':memory:':
return True
@ -319,7 +319,7 @@ class SQLiteSession(MemorySession):
# File processing
def get_file(self, md5_digest, file_size, cls):
async def get_file(self, md5_digest, file_size, cls):
tuple_ = self._cursor().execute(
'select id, hash from sent_files '
'where md5_digest = ? and file_size = ? and type = ?',
@ -329,7 +329,7 @@ class SQLiteSession(MemorySession):
# Both allowed classes have (id, access_hash) as parameters
return cls(tuple_[0], tuple_[1])
def cache_file(self, md5_digest, file_size, instance):
async def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
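
Every persistence hook on the session (save, delete, get_input_entity, cache_file, get_file, and in SQLiteSession also the update-state accessors) is now a coroutine, so call sites have to await it; a bare call would only create a coroutine object and never run. A minimal sketch of the new calling convention (remember_peer is illustrative, not part of the library):

async def remember_peer(session, key):
    # Both the lookup and the save are coroutines after this change.
    peer = await session.get_input_entity(key)
    await session.save()
    return peer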

View File

@ -1,8 +1,8 @@
import asyncio
import logging
import os
from asyncio import Lock
from datetime import timedelta
from asyncio import Lock, Event
from datetime import timedelta, datetime
import platform
from . import version, utils
from .crypto import rsa
@ -13,7 +13,7 @@ from .errors import (
RpcCallFailError
)
from .network import authenticator, MtProtoSender, ConnectionTcpFull
from .sessions import Session, SQLiteSession
from .sessions import Session
from .tl import TLObject
from .tl.all_tlobjects import LAYER
from .tl.functions import (
@ -25,10 +25,17 @@ from .tl.functions.auth import (
from .tl.functions.help import (
GetCdnConfigRequest, GetConfigRequest
)
from .tl.functions.updates import GetStateRequest
from .tl.functions.updates import GetStateRequest, GetDifferenceRequest
from .tl.types import (
Pong, PeerUser, PeerChat, Message, Updates, UpdateShort, UpdateNewChannelMessage, UpdateEditChannelMessage,
UpdateDeleteChannelMessages, UpdateChannelTooLong, UpdateNewMessage, NewSessionCreated, UpdatesTooLong,
UpdateShortSentMessage, MessageEmpty, UpdateShortMessage, UpdateShortChatMessage, UpdatesCombined
)
from .tl.types.auth import ExportedAuthorization
from .update_state import UpdateState
from .tl.types.messages import AffectedMessages, AffectedHistory
from .tl.types.updates import DifferenceEmpty, DifferenceTooLong, DifferenceSlice
MAX_TIMEOUT = 15 # in seconds
DEFAULT_DC_ID = 4
DEFAULT_IPV4_IP = '149.154.167.51'
DEFAULT_IPV6_IP = '[2001:67c:4e8:f002::a]'
@ -71,8 +78,11 @@ class TelegramBareClient:
use_ipv6=False,
proxy=None,
timeout=timedelta(seconds=5),
ping_delay=timedelta(minutes=1),
update_handler=None,
unauthorized_handler=None,
loop=None,
report_errors=True,
report_errors=None,
device_model=None,
system_version=None,
app_version=None,
@ -87,12 +97,8 @@ class TelegramBareClient:
self._use_ipv6 = use_ipv6
# Determine what session object we have
if isinstance(session, str) or session is None:
session = SQLiteSession(session)
elif not isinstance(session, Session):
raise TypeError(
'The given session must be a str or a Session instance.'
)
if not isinstance(session, Session):
raise TypeError('The given session must be a Session instance.')
self._loop = loop if loop else asyncio.get_event_loop()
@ -105,7 +111,8 @@ class TelegramBareClient:
DEFAULT_PORT
)
session.report_errors = report_errors
if report_errors is not None:
session.report_errors = report_errors
self.session = session
self.api_id = int(api_id)
self.api_hash = api_hash
@ -128,10 +135,6 @@ class TelegramBareClient:
# expensive operation.
self._exported_sessions = {}
# This member will process updates if enabled.
# One may change self.updates.enabled at any later point.
self.updates = UpdateState(self._loop)
# Used on connection - the user may modify these and reconnect
system = platform.uname()
self.device_model = device_model or system.system or 'Unknown'
@ -140,6 +143,12 @@ class TelegramBareClient:
self.lang_code = lang_code
self.system_lang_code = system_lang_code
self._state = None
self._sync_loading = False
self.update_handler = update_handler
self.unauthorized_handler = unauthorized_handler
self._last_update = datetime.now()
# Despite the state of the real connection, keep track of whether
# the user has explicitly called .connect() or .disconnect() here.
# This information is required by the read thread, who will be the
@ -147,90 +156,72 @@ class TelegramBareClient:
# doesn't explicitly call .disconnect(), thus telling it to stop
# retrying. The main thread, knowing there is a background thread
# attempting reconnection as soon as it happens, will just sleep.
self._user_connected = False
# Save whether the user is authorized here (a.k.a. logged in)
self._authorized = None # None = We don't know yet
# The first request must be in invokeWithLayer(initConnection(X)).
# See https://core.telegram.org/api/invoking#saving-client-info.
self._first_request = True
self._user_connected = Event(loop=self._loop)
self._authorized = False
self._shutdown = False
self._recv_loop = None
self._ping_loop = None
self._state_loop = None
self._idling = asyncio.Event()
self._reconnection_loop = None
# Default PingRequest delay
self._ping_delay = timedelta(minutes=1)
# Also have another delay for GetStateRequest.
#
# If the connection is kept alive for long without invoking any
# high level request the server simply stops sending updates.
# TODO maybe we can have ._last_request instead if any req works?
self._state_delay = timedelta(hours=1)
if isinstance(ping_delay, timedelta):
self._ping_delay = ping_delay.seconds
elif isinstance(ping_delay, (int, float)):
self._ping_delay = float(ping_delay)
else:
raise TypeError('Invalid timeout type', type(timeout))
# endregion
# region Connecting
async def connect(self, _sync_updates=True):
"""Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which
may require a call (or several) to 'sign_in' for the first time.
Note that the optional parameters are meant for internal use.
If '_sync_updates', sync_updates() will be called and a
second thread will be started if necessary. Note that this
will FAIL if the client is not connected to the user's
native data center, raising a "UserMigrateError", and
calling .disconnect() in the process.
"""
__log__.info('Connecting to %s:%d...',
self.session.server_address, self.session.port)
def __del__(self):
self.disconnect()
async def connect(self):
try:
await self._sender.connect()
__log__.info('Connection success!')
# Connection was successful! Try syncing the update state
# UNLESS '_sync_updates' is False (we probably are in
# another data center and this would raise UserMigrateError)
# to also assert whether the user is logged in or not.
self._user_connected = True
if self._authorized is None and _sync_updates:
if not self._sender.is_connected():
await self._sender.connect()
if not self.session.auth_key:
try:
await self.sync_updates()
await self._set_connected_and_authorized()
self.session.auth_key, self.session.time_offset = \
await authenticator.do_authentication(self._sender.connection)
await self.session.save()
except BrokenAuthKeyError:
self._user_connected.clear()
return False
if TelegramBareClient._config is None:
TelegramBareClient._config = await self(self._wrap_init_connection(GetConfigRequest()))
if not self._authorized:
try:
self._state = await self(self._wrap_init_connection(GetStateRequest()))
self._authorized = True
except UnauthorizedError:
self._authorized = False
elif self._authorized:
await self._set_connected_and_authorized()
self.run_loops()
self._user_connected.set()
return True
except TypeNotFoundError as e:
# This is fine, probably layer migration
__log__.warning('Connection failed, got unexpected type with ID '
'%s. Migrating?', hex(e.invalid_constructor_id))
self.disconnect()
return await self.connect(_sync_updates=_sync_updates)
self.disconnect(False)
return await self.connect()
except AuthKeyError as e:
# As of late March 2018 there were two AUTH_KEY_DUPLICATED
# reports. Retrying with a clean auth_key should fix this.
__log__.warning('Auth key error %s. Clearing it and retrying.', e)
self.disconnect()
self.session.auth_key = None
self.session.save()
return self.connect(_sync_updates=_sync_updates)
if not self._authorized:
__log__.warning('Auth key error %s. Clearing it and retrying.', e)
self.disconnect(False)
self.session.auth_key = None
return self.connect()
else:
raise
except (RPCError, ConnectionError) as e:
# Probably errors from the previous session, ignore them
__log__.error('Connection failed due to %s', e)
self.disconnect()
self.disconnect(False)
return False
def is_connected(self):
@ -249,24 +240,11 @@ class TelegramBareClient:
query=query
))
def disconnect(self):
def disconnect(self, shutdown=True):
"""Disconnects from the Telegram server"""
__log__.info('Disconnecting...')
self._user_connected = False
self._sender.disconnect()
if self._recv_loop:
self._recv_loop.cancel()
self._recv_loop = None
if self._ping_loop:
self._ping_loop.cancel()
self._ping_loop = None
if self._state_loop:
self._state_loop.cancel()
self._state_loop = None
# TODO Shall we clear the _exported_sessions, or may be reused?
self._first_request = True # On reconnect it will be first again
self.session.set_update_state(0, self.updates.get_update_state(0))
self.session.close()
self._shutdown = shutdown
self._user_connected.clear()
self._sender.disconnect(clear_pendings=shutdown)
async def _reconnect(self, new_dc=None):
"""If 'new_dc' is not set, only a call to .connect() will be made
@ -277,32 +255,23 @@ class TelegramBareClient:
current data center, clears the auth key for the old DC, and
connects to the new data center.
"""
if new_dc is None:
# Assume we are disconnected due to some error, so connect again
try:
if self.is_connected():
__log__.info('Reconnection aborted: already connected')
return True
await self._reconnect_lock.acquire()
try:
# Another thread may have connected again, so check that first
if self.is_connected() and new_dc is None:
return True
__log__.info('Attempting reconnection...')
return await self.connect()
except ConnectionResetError as e:
__log__.warning('Reconnection failed due to %s', e)
return False
else:
# Since we're reconnecting possibly due to a UserMigrateError,
# we need to first know the Data Centers we can connect to. Do
# that before disconnecting.
dc = await self._get_dc(new_dc)
__log__.info('Reconnecting to new data center %s', dc)
if new_dc is not None:
dc = await self._get_dc(new_dc)
self.disconnect(False)
self.session.set_dc(dc.id, dc.ip_address, dc.port)
await self.session.save()
self.session.set_dc(dc.id, dc.ip_address, dc.port)
# auth_key's are associated with a server, which has now changed
# so it's not valid anymore. Set to None to force recreating it.
self.session.auth_key = None
self.session.save()
self.disconnect()
return await self.connect()
except (ConnectionResetError, TimeoutError):
return False
finally:
self._reconnect_lock.release()
def set_proxy(self, proxy):
"""Change the proxy used by the connections.
@ -348,7 +317,7 @@ class TelegramBareClient:
"""
# Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
# for clearly showing how to export the authorization! ^^
session = self._exported_sessions.get(dc_id)
session = self._exported_sessions.get(dc_id, None)
if session:
export_auth = None # Already bound with the auth key
else:
@ -377,7 +346,7 @@ class TelegramBareClient:
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
await client.connect(_sync_updates=False)
await client.connect()
if isinstance(export_auth, ExportedAuthorization):
await client(ImportAuthorizationRequest(
id=export_auth.id, bytes=export_auth.bytes
@ -385,12 +354,11 @@ class TelegramBareClient:
elif export_auth is not None:
__log__.warning('Unknown export auth type %s', export_auth)
client._authorized = True # We exported the auth, so we got auth
return client
async def _get_cdn_client(self, cdn_redirect):
"""Similar to ._get_exported_client, but for CDNs"""
session = self._exported_sessions.get(cdn_redirect.dc_id)
session = self._exported_sessions.get(cdn_redirect.dc_id, None)
if not session:
dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
session = self.session.clone()
@ -410,8 +378,7 @@ class TelegramBareClient:
# We won't be calling GetConfigRequest because it's only called
# when needed by ._get_dc, and also it's static so it's likely
# set already. Avoid invoking non-CDN methods by not syncing updates.
await client.connect(_sync_updates=False)
client._authorized = self._authorized
await client.connect()
return client
# endregion
@ -461,100 +428,70 @@ class TelegramBareClient:
which = '{} requests ({})'.format(
len(request), [type(x).__name__ for x in request])
is_ping = any(isinstance(x, PingRequest) for x in request)
msg_ids = []
__log__.debug('Invoking %s', which)
call_receive = \
not self._idling.is_set() or self._reconnect_lock.locked()
try:
for retry in range(retries):
result = None
for sub_retry in range(retries):
msg_ids, result = await self._invoke(retry, request, ordered, msg_ids)
if msg_ids:
break
if not self.is_connected():
break
__log__.error('Subretry %d is failed' % sub_retry)
if result is None:
if not is_ping:
try:
pong = await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)), retries=1)
if isinstance(pong, Pong):
__log__.error('Connection is live, but no answer on %d retry' % retry)
continue
except RuntimeError:
pass # continue to reconnect
if self.is_connected() and (retry + 1) % 2 == 0:
__log__.error('Force disconnect on %d retry' % retry)
self.disconnect(False)
self._sender.forget_pendings(msg_ids)
msg_ids = []
if not self.is_connected():
__log__.error('Pause before new retry on %d retry' % retry)
await asyncio.sleep(retry + 1, loop=self._loop)
else:
return result[0] if single else result
finally:
self._sender.forget_pendings(msg_ids)
for retry in range(retries):
result = await self._invoke(call_receive, retry, request,
ordered=ordered)
if result is not None:
return result[0] if single else result
log = __log__.info if retry == 0 else __log__.warning
log('Invoking %s failed %d times, connecting again and retrying',
which, retry + 1)
await asyncio.sleep(1)
if not self._reconnect_lock.locked():
with await self._reconnect_lock:
await self._reconnect()
raise RuntimeError('Number of retries reached 0 for {}.'.format(
which
))
raise RuntimeError('Number of retries is exceeded for {}.'.format(which))
# Let people use client.invoke(SomeRequest()) instead client(...)
invoke = __call__
async def _invoke(self, call_receive, retry, requests, ordered=False):
async def _invoke(self, retry, requests, ordered, msg_ids):
try:
# Ensure that we start with no previous errors (i.e. resending)
for x in requests:
x.rpc_error = None
if not msg_ids:
msg_ids = await self._sender.send(requests, ordered)
if not self.session.auth_key:
__log__.info('Need to generate new auth key before invoking')
self._first_request = True
self.session.auth_key, self.session.time_offset = \
await authenticator.do_authentication(self._sender.connection)
# Ensure that we start with no previous errors (i.e. resending)
for x in requests:
x.rpc_error = None
if self._first_request:
__log__.info('Initializing a new connection while invoking')
if len(requests) == 1:
requests = [self._wrap_init_connection(requests[0])]
else:
# We need a SINGLE request (like GetConfig) to init conn.
# Once that's done, the N original requests will be
# invoked.
TelegramBareClient._config = await self(
self._wrap_init_connection(GetConfigRequest())
)
await self._sender.send(requests, ordered=ordered)
if not call_receive:
await asyncio.wait(
list(map(lambda x: x.confirm_received.wait(), requests)),
timeout=self._sender.connection.get_timeout(),
loop=self._loop
)
if self._user_connected.is_set():
fut = asyncio.gather(*list(map(lambda x: x.confirm_received.wait(), requests)), loop=self._loop)
self._loop.call_later(self._sender.connection.get_timeout(), fut.cancel)
await fut
else:
while not all(x.confirm_received.is_set() for x in requests):
await self._sender.receive(update_state=self.updates)
except BrokenAuthKeyError:
__log__.error('Authorization key seems broken and was invalid!')
self.session.auth_key = None
except TypeNotFoundError as e:
# Only occurs when we call receive. May happen when
# we need to reconnect to another DC on login and
# Telegram somehow sends old objects (like configOld)
self._first_request = True
__log__.warning('Read unknown TLObject code ({}). '
'Setting again first_request flag.'
.format(hex(e.invalid_constructor_id)))
except TimeoutError:
__log__.warning('Invoking timed out') # We will just retry
await self._sender.receive(self._updates_handler)
except (TimeoutError, asyncio.CancelledError):
__log__.error('Timeout on %d retry' % retry)
except ConnectionResetError as e:
__log__.warning('Connection was reset while invoking')
if self._user_connected:
# Server disconnected us, __call__ will try reconnecting.
try:
self._sender.disconnect()
except:
pass
return None
else:
# User never called .connect(), so raise this error.
raise RuntimeError('Tried to invoke without .connect()') from e
# Clear the flag if we got this far
self._first_request = False
if self._shutdown:
raise
__log__.error('Connection reset on %d retry: %r' % (retry, e))
try:
raise next(x.rpc_error for x in requests if x.rpc_error)
@ -562,15 +499,32 @@ class TelegramBareClient:
if any(x.result is None for x in requests):
# "A container may only be accepted or
# rejected by the other party as a whole."
return None
return msg_ids, None
return [x.result for x in requests]
for req in requests:
if isinstance(req.result, TLObject) and req.result.SUBCLASS_OF_ID == Updates.SUBCLASS_OF_ID:
self._updates_handler(req.result, False, False)
if isinstance(req.result, (AffectedMessages, AffectedHistory)): # due to affect to pts
self._updates_handler(UpdateShort(req.result, None), False, False)
except (PhoneMigrateError, NetworkMigrateError,
UserMigrateError) as e:
return msg_ids, [x.result for x in requests]
except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e:
if isinstance(e, (PhoneMigrateError, NetworkMigrateError)):
if self._authorized:
raise
else:
self.session.auth_key = None # Force creating new auth_key
__log__.error(
'DC error when invoking request, '
'attempting to reconnect at DC {}'.format(e.new_dc)
)
await self._reconnect(new_dc=e.new_dc)
return await self._invoke(call_receive, retry, requests)
self._sender.forget_pendings(msg_ids)
msg_ids = []
return msg_ids, None
except (ServerError, RpcCallFailError) as e:
# Telegram is having some issues, just retry
@ -582,7 +536,16 @@ class TelegramBareClient:
raise
await asyncio.sleep(e.seconds, loop=self._loop)
return None
return msg_ids, None
except UnauthorizedError:
if self._authorized:
__log__.error('Authorization has lost')
self._authorized = False
self.disconnect()
if self.unauthorized_handler:
await self.unauthorized_handler(self)
raise
# Some really basic functionality
@ -602,74 +565,216 @@ class TelegramBareClient:
# region Updates handling
async def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates.
"""
self.updates.process(await self(GetStateRequest()))
async def _handle_update(self, update, seq_start, seq, date, do_get_diff, do_handlers, users=(), chats=()):
if isinstance(update, (UpdateNewChannelMessage, UpdateEditChannelMessage,
UpdateDeleteChannelMessages, UpdateChannelTooLong)):
# TODO: channel updates have their own pts sequences, so they require individual pts values
return # ignore channel updates to keep pts in the main _state in the correct state
if hasattr(update, 'pts'):
new_pts = self._state.pts + getattr(update, 'pts_count', 0)
if new_pts < update.pts:
__log__.debug('Have got a hole between pts => waiting 0.5 sec')
await asyncio.sleep(0.5, loop=self._loop)
if new_pts < update.pts:
if do_get_diff and not self._sync_loading:
__log__.debug('The hole between pts has not disappeared => going to get differences')
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
if update.pts > self._state.pts:
self._state.pts = update.pts
elif getattr(update, 'pts_count', 0) > 0:
__log__.debug('Have got the duplicate update (basing on pts) => ignoring')
return
elif hasattr(update, 'qts'):
if self._state.qts + 1 < update.qts:
__log__.debug('Have got a hole between qts => waiting 0.5 sec')
await asyncio.sleep(0.5, loop=self._loop)
if self._state.qts + 1 < update.qts:
if do_get_diff and not self._sync_loading:
__log__.debug('The hole between qts has not disappeared => going to get differences')
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
if update.qts > self._state.qts:
self._state.qts = update.qts
else:
__log__.debug('Have got the duplicate update (basing on qts) => ignoring')
return
elif seq > 0:
if seq_start > self._state.seq + 1:
__log__.debug('Have got a hole between seq => waiting 0.5 sec')
await asyncio.sleep(0.5, loop=self._loop)
if seq_start > self._state.seq + 1:
if do_get_diff and not self._sync_loading:
__log__.debug('The hole between seq has not disappeared => going to get differences')
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
self._state.seq = seq
self._state.date = max(self._state.date, date)
# endregion
if do_handlers and self.update_handler:
asyncio.ensure_future(self.update_handler(self, update, users, chats), loop=self._loop)
# Constant read
async def _get_difference(self):
self._sync_loading = True
try:
difference = await self(GetDifferenceRequest(self._state.pts, self._state.date, self._state.qts))
if isinstance(difference, DifferenceEmpty):
__log__.debug('Have got DifferenceEmpty => just update seq and date')
self._state.seq = difference.seq
self._state.date = difference.date
return
if isinstance(difference, DifferenceTooLong):
__log__.debug('Have got DifferenceTooLong => update pts and try again')
self._state.pts = difference.pts
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
__log__.debug('Preparing updates from differences')
self._state = difference.intermediate_state \
if isinstance(difference, DifferenceSlice) else difference.state
messages = [UpdateNewMessage(msg, self._state.pts, 0) for msg in difference.new_messages]
self._updates_handler(
Updates(messages + difference.other_updates,
difference.users, difference.chats, self._state.date, self._state.seq),
False
)
if isinstance(difference, DifferenceSlice):
asyncio.ensure_future(self._get_difference(), loop=self._loop)
except ConnectionResetError:  # happens when unauthorized, since _get_difference often runs in the background
pass
except Exception as e:
__log__.exception('Exception on _get_difference: %r', e)
finally:
self._sync_loading = False
# This is async so that the overrided version in TelegramClient can be
# async without problems.
async def _set_connected_and_authorized(self):
self._authorized = True
# TODO: Some of this logic was moved from MtProtoSender and probably should be moved back.
def _updates_handler(self, updates, do_get_diff=True, do_handlers=True):
if do_get_diff:
self._last_update = datetime.now()
if isinstance(updates, NewSessionCreated):
self.session.salt = updates.server_salt
if self._state is None:
return False # not ready yet
if self._sync_loading and do_get_diff:
return False # ignore all if in sync except from difference (do_get_diff = False)
if isinstance(updates, (NewSessionCreated, UpdatesTooLong)):
if do_get_diff: # to prevent possible loops
__log__.debug('Have got %s => going to get differences', type(updates))
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return False
seq = getattr(updates, 'seq', 0)
seq_start = getattr(updates, 'seq_start', seq)
date = getattr(updates, 'date', self._state.date)
if isinstance(updates, UpdateShort):
asyncio.ensure_future(
self._handle_update(updates.update, seq_start, seq, date, do_get_diff, do_handlers),
loop=self._loop
)
return True
if isinstance(updates, UpdateShortSentMessage):
asyncio.ensure_future(self._handle_update(
UpdateNewMessage(MessageEmpty(updates.id), updates.pts, updates.pts_count),
seq_start, seq, date, do_get_diff, do_handlers
), loop=self._loop)
return True
if isinstance(updates, (UpdateShortMessage, UpdateShortChatMessage)):
from_id = getattr(updates, 'from_id', self.session.user_id)
to_id = updates.user_id if isinstance(updates, UpdateShortMessage) else updates.chat_id
if not updates.out:
from_id, to_id = to_id, from_id
to_id = PeerUser(to_id) if isinstance(updates, UpdateShortMessage) else PeerChat(to_id)
message = Message(
id=updates.id, to_id=to_id, date=updates.date, message=updates.message, out=updates.out,
mentioned=updates.mentioned, media_unread=updates.media_unread, silent=updates.silent,
from_id=from_id, fwd_from=updates.fwd_from, via_bot_id=updates.via_bot_id,
reply_to_msg_id=updates.reply_to_msg_id, entities=updates.entities
)
asyncio.ensure_future(self._handle_update(
UpdateNewMessage(message, updates.pts, updates.pts_count),
seq_start, seq, date, do_get_diff, do_handlers
), loop=self._loop)
return True
if isinstance(updates, (Updates, UpdatesCombined)):
for upd in updates.updates:
asyncio.ensure_future(
self._handle_update(upd, seq_start, seq, date, do_get_diff, do_handlers,
updates.users, updates.chats),
loop=self._loop
)
return True
if do_get_diff: # to prevent possible loops
__log__.debug('Have got unsupported type of updates: %s => going to get differences', type(updates))
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return False
def run_loops(self):
if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
if self._state_loop is None:
self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)
while True:
if self._shutdown:
break
try:
await self._user_connected.wait()
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay, loop=self._loop)
except RuntimeError:
pass # Can be not happy due to connection problems
except asyncio.CancelledError:
break
except:
self._ping_loop = None
raise
self._ping_loop = None
async def _state_loop_impl(self):
while self._user_connected:
await asyncio.sleep(self._state_delay.seconds, loop=self._loop)
await self._sender.send(GetStateRequest())
async def _recv_loop_impl(self):
__log__.info('Starting to wait for items from the network')
self._idling.set()
need_reconnect = False
while self._user_connected:
timeout = 1
while True:
if self._shutdown:
break
try:
if need_reconnect:
__log__.info('Attempting reconnection from read loop')
need_reconnect = False
with await self._reconnect_lock:
while self._user_connected and not await self._reconnect():
# Retry forever, this is instant messaging
await asyncio.sleep(0.1, loop=self._loop)
# Telegram seems to kick us every 1024 items received
# from the network not considering things like bad salt.
# We must execute some *high level* request (that's not
# a ping) if we want to receive updates again.
# TODO Test if getDifference works too (better alternative)
await self._sender.send(GetStateRequest())
__log__.debug('Receiving items from the network...')
await self._sender.receive(update_state=self.updates)
if self._user_connected.is_set():
if self._authorized and datetime.now() - self._last_update > timedelta(minutes=15):
__log__.debug('No updates for 15 minutes => going to get differences')
self._last_update = datetime.now()
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
await self._sender.receive(self._updates_handler)
else:
if await self._reconnect():
__log__.info('Connection has established')
timeout = 1
else:
await asyncio.sleep(timeout, loop=self._loop)
timeout = min(timeout * 2, MAX_TIMEOUT)
except TimeoutError:
# No problem.
__log__.debug('Receiving items from the network timed out')
except ConnectionError:
need_reconnect = True
__log__.error('Connection was reset while receiving items')
await asyncio.sleep(1, loop=self._loop)
except:
self._idling.clear()
raise
self._idling.clear()
__log__.info('Connection closed by the user, not reading anymore')
pass
except ConnectionResetError as error:
__log__.info('Connection reset error in recv loop: %r' % error)
self._user_connected.clear()
except asyncio.CancelledError:
self.disconnect()
break
except Exception as error:
# Unknown exception, pass it to the main thread
__log__.exception('[ERROR: %r] on the read loop, please report', error)
self._recv_loop = None
if self._shutdown and self._ping_loop:
self._ping_loop.cancel()
# endregion
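
The gap handling in _handle_update above follows Telegram's documented sequence rule for pts-carrying updates: with local_pts being the stored value, local_pts + pts_count == pts means the update can be applied, a larger pts means intermediate updates were missed (so the client fetches the difference, as _get_difference does), and a smaller one means it was already applied. A compact standalone restatement of that check (simplified; qts and seq are handled analogously above):

def classify_update(local_pts: int, pts: int, pts_count: int) -> str:
    """Classify a pts-carrying update against the locally stored state."""
    expected = local_pts + pts_count
    if expected == pts:
        return 'apply'  # in order: apply it and set local_pts = pts
    if expected > pts:
        return 'skip'   # already applied earlier: ignore the duplicate
    return 'gap'        # updates were missed: call updates.getDifference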

View File

@ -510,7 +510,7 @@ class TelegramClient(TelegramBareClient):
return False
self.disconnect()
self.session.delete()
await self.session.delete()
self._authorized = False
return True
@ -1805,7 +1805,7 @@ class TelegramClient(TelegramBareClient):
to_cache = utils.get_input_photo(msg.media.photo)
else:
to_cache = utils.get_input_document(msg.media.document)
self.session.cache_file(md5, size, to_cache)
await self.session.cache_file(md5, size, to_cache)
return msg
@ -1849,7 +1849,7 @@ class TelegramClient(TelegramBareClient):
input_photo = utils.get_input_photo((await self(UploadMediaRequest(
entity, media=InputMediaUploadedPhoto(fh)
))).photo)
self.session.cache_file(fh.md5, fh.size, input_photo)
await self.session.cache_file(fh.md5, fh.size, input_photo)
fh = input_photo
if captions:
@ -1957,7 +1957,7 @@ class TelegramClient(TelegramBareClient):
file = stream.read()
hash_md5.update(file)
if use_cache:
cached = self.session.get_file(
cached = await self.session.get_file(
hash_md5.digest(), file_size, cls=use_cache
)
if cached:
@ -2122,7 +2122,7 @@ class TelegramClient(TelegramBareClient):
media, file, date, progress_callback
)
elif isinstance(media, MessageMediaContact):
return await self._download_contact(
return self._download_contact(
media, file
)
@ -2472,7 +2472,7 @@ class TelegramClient(TelegramBareClient):
be passed instead.
"""
self.updates.handler = self._on_handler
self.update_handler = self._on_handler
if isinstance(event, type):
event = event()
elif not event:
@ -2555,7 +2555,7 @@ class TelegramClient(TelegramBareClient):
# infinite loop here (so check against old pts to stop)
break
self.updates.process(Updates(
self._updates_handler(Updates(
users=d.users,
chats=d.chats,
date=state.date,
@ -2573,10 +2573,6 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier
async def _set_connected_and_authorized(self):
await super()._set_connected_and_authorized()
await self._check_events_pending_resolve()
async def get_entity(self, entity):
"""
Turns the given entity into a valid Telegram user or chat.
@ -2694,7 +2690,7 @@ class TelegramClient(TelegramBareClient):
try:
# Nobody with this username, maybe it's an exact name/title
return await self.get_entity(
self.session.get_input_entity(string))
await self.session.get_input_entity(string))
except ValueError:
pass
@ -2729,7 +2725,7 @@ class TelegramClient(TelegramBareClient):
try:
# First try to get the entity from cache, otherwise figure it out
return self.session.get_input_entity(peer)
return await self.session.get_input_entity(peer)
except ValueError:
pass

View File

@ -192,3 +192,6 @@ class TLObject:
@classmethod
def from_reader(cls, reader):
return TLObject()
def __repr__(self):
return self.__str__()
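
Defining __repr__ in terms of __str__ is what makes the pretty formatting of TLObjects show up everywhere: containers and most logging go through repr(), not str(). The same pattern in isolation (Pretty is a made-up class, not a Telethon type):

class Pretty:
    def __str__(self):
        return 'Pretty(nicely formatted)'

    def __repr__(self):
        return self.__str__()

print(Pretty())    # uses __str__
print([Pretty()])  # lists repr() their items, so the pretty form appears here too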

View File

@ -1,65 +0,0 @@
import asyncio
import itertools
import logging
from datetime import datetime
from . import utils
from .tl import types as tl
__log__ = logging.getLogger(__name__)
class UpdateState:
"""
Used to hold the current state of processed updates.
To retrieve an update, :meth:`poll` should be called.
"""
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, loop=None):
self.handler = None
self._loop = loop if loop else asyncio.get_event_loop()
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def handle_update(self, update):
if self.handler:
asyncio.ensure_future(self.handler(update), loop=self._loop)
def get_update_state(self, entity_id):
"""Gets the updates.State corresponding to the given entity or 0."""
return self._state
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state')
self._state = update
return # Nothing else to be done
if hasattr(update, 'pts'):
self._state.pts = update.pts
# After running the script for over an hour and receiving over
# 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports.
# This should only be used as read-only.
if isinstance(update, tl.UpdateShort):
update.update._entities = {}
self.handle_update(update.update)
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
u._entities = entities
self.handle_update(u)
# TODO Handle "tl.UpdatesTooLong"
else:
update._entities = {}
self.handle_update(update)