Merge branch 'asyncio' into asyncio-upstream

* asyncio:
  No need to save (the salt is no longer kept in the DB)
  Fix a very rare exception on reconnect
  Move updates_handler out of MtProtoSender so the GC works properly; fix the unauth_handler log format
  Fix memory leaks
  Pretty formatting of TLObjects
  Clear pending requests more accurately
  Another attempt to prevent duplicates
  Update handling and other refactoring
  Add a SocketClosed exception
  Refactor TcpClient
  Log socket OSError
  Catch network errors more aggressively
  Catch "No route to host" and other errnos that should trigger a reconnect

# Conflicts (resolved):
#	telethon/extensions/tcp_client.py
#	telethon/network/mtproto_sender.py
#	telethon/telegram_bare_client.py
#	telethon/tl/session.py
Commit 43a0226b33 by Andrey Egorov, 2018-06-14 14:34:08 +03:00
9 changed files with 550 additions and 484 deletions

File: telethon/extensions/tcp_client.py

@@ -6,32 +6,33 @@ import asyncio
 import errno
 import logging
 import socket
+import time
 from datetime import timedelta
 from io import BytesIO, BufferedWriter
 
+MAX_TIMEOUT = 15  # in seconds
 CONN_RESET_ERRNOS = {
     errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
-    errno.EINVAL, errno.ENOTCONN
+    errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
+    errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
+    errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
+    errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
 }
+# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
+# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
 
 try:
     import socks
 except ImportError:
     socks = None
 
-MAX_TIMEOUT = 15  # in seconds
-CONN_RESET_ERRNOS = {
-    errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
-    errno.EINVAL, errno.ENOTCONN
-}
 
 __log__ = logging.getLogger(__name__)
 
 
 class TcpClient:
     """A simple TCP client to ease the work with sockets and proxies."""
 
+    class SocketClosed(ConnectionError):
+        pass
+
     def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
         """
         Initializes the TCP client.
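
As an illustrative sketch (not part of this commit), the widened CONN_RESET_ERRNOS set above is what the read/write paths later check before treating an OSError as a dead connection. Only the set name comes from the diff; the helper below is an assumption for illustration and expects CONN_RESET_ERRNOS to be in scope.

import errno

def is_connection_reset(exc: OSError) -> bool:
    # True when the caller should close the socket and reconnect
    return exc.errno in CONN_RESET_ERRNOS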
@@ -42,6 +43,8 @@ class TcpClient:
         self.proxy = proxy
         self._socket = None
         self._loop = loop if loop else asyncio.get_event_loop()
+        self._closed = asyncio.Event(loop=self._loop)
+        self._closed.set()
 
         if isinstance(timeout, timedelta):
             self.timeout = timeout.seconds
@@ -76,41 +79,28 @@
         else:
             mode, address = socket.AF_INET, (ip, port)
 
-        timeout = 1
-        while True:
-            try:
-                if not self._socket:
-                    self._recreate_socket(mode)
-
-                await self._loop.sock_connect(self._socket, address)
-                break  # Successful connection, stop retrying to connect
-            except ConnectionError:
-                self._socket = None
-                await asyncio.sleep(timeout)
-                timeout = min(timeout * 2, MAX_TIMEOUT)
-            except OSError as e:
-                __log__.info('OSError "%s" raised while connecting', e)
-                # Stop retrying to connect if proxy connection error occurred
-                if socks and isinstance(e, socks.ProxyConnectionError):
-                    raise
-                # There are some errors that we know how to handle, and
-                # the loop will allow us to retry
-                if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL,
-                               errno.ECONNREFUSED,  # Windows-specific follow
-                               getattr(errno, 'WSAEACCES', None)):
-                    # Bad file descriptor, i.e. socket was closed, set it
-                    # to none to recreate it on the next iteration
-                    self._socket = None
-                    await asyncio.sleep(timeout)
-                    timeout *= 2
-                    if timeout > MAX_TIMEOUT:
-                        raise
-                else:
-                    raise
+        try:
+            if not self._socket:
+                self._recreate_socket(mode)
+
+            await asyncio.wait_for(
+                self._loop.sock_connect(self._socket, address),
+                timeout=self.timeout,
+                loop=self._loop
+            )
+
+            self._closed.clear()
+        except asyncio.TimeoutError as e:
+            raise TimeoutError() from e
+        except OSError as e:
+            if e.errno in CONN_RESET_ERRNOS:
+                self._raise_connection_reset(e)
+            else:
+                raise
 
     def _get_connected(self):
         """Determines whether the client is connected or not."""
-        return self._socket is not None and self._socket.fileno() >= 0
+        return not self._closed.is_set()
 
     connected = property(fget=_get_connected)
@@ -118,12 +108,29 @@
         """Closes the connection."""
         try:
             if self._socket is not None:
-                self._socket.shutdown(socket.SHUT_RDWR)
+                if self.connected:
+                    self._socket.shutdown(socket.SHUT_RDWR)
                 self._socket.close()
         except OSError:
             pass  # Ignore ENOTCONN, EBADF, and any other error when closing
         finally:
             self._socket = None
+            self._closed.set()
+
+    async def _wait_close(self, coro):
+        done, running = await asyncio.wait(
+            [coro, self._closed.wait()],
+            timeout=self.timeout,
+            return_when=asyncio.FIRST_COMPLETED,
+            loop=self._loop
+        )
+        for r in running:
+            r.cancel()
+        if not self.connected:
+            raise self.SocketClosed()
+        if not done:
+            raise TimeoutError()
+        return done.pop().result()
 
     async def write(self, data):
         """
@@ -131,21 +138,12 @@
 
         :param data: the data to send.
         """
-        if self._socket is None:
-            self._raise_connection_reset(None)
+        if not self.connected:
+            raise ConnectionResetError('No connection')
 
         try:
-            await asyncio.wait_for(
-                self.sock_sendall(data),
-                timeout=self.timeout,
-                loop=self._loop
-            )
-        except asyncio.TimeoutError as e:
-            __log__.debug('socket.timeout "%s" while writing data', e)
-            raise TimeoutError() from e
-        except ConnectionError as e:
-            __log__.info('ConnectionError "%s" while writing data', e)
-            self._raise_connection_reset(e)
+            await self._wait_close(self.sock_sendall(data))
+        except self.SocketClosed:
+            raise ConnectionResetError('Socket has closed')
         except OSError as e:
             __log__.info('OSError "%s" while writing data', e)
             if e.errno in CONN_RESET_ERRNOS:
@@ -160,21 +158,15 @@
 
         :param size: the size of the block to be read.
         :return: the read data with len(data) == size.
         """
-        if self._socket is None:
-            self._raise_connection_reset(None)
-
         with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
             bytes_left = size
-            partial = b''
             while bytes_left != 0:
+                if not self.connected:
+                    raise ConnectionResetError('No connection')
                 try:
-                    if self._socket is None:
-                        self._raise_connection_reset()
-                    partial = await asyncio.wait_for(
-                        self.sock_recv(bytes_left),
-                        timeout=self.timeout,
-                        loop=self._loop
-                    )
-                except asyncio.TimeoutError as e:
+                    partial = await self._wait_close(self.sock_recv(bytes_left))
+                except TimeoutError as e:
                     # These are somewhat common if the server has nothing
                     # to send to us, so use a lower logging priority.
                     if bytes_left < size:
@@ -187,10 +179,9 @@
                         'socket.timeout "%s" while reading data', e
                     )
-                    raise TimeoutError() from e
-                except ConnectionError as e:
-                    __log__.info('ConnectionError "%s" while reading data', e)
-                    self._raise_connection_reset(e)
+                    raise
+                except self.SocketClosed:
+                    raise ConnectionResetError('Socket has closed while reading data')
                 except OSError as e:
                     if e.errno != errno.EBADF:
                         # Ignore bad file descriptor while closing
@@ -202,7 +193,7 @@
                         raise
 
                 if len(partial) == 0:
-                    self._raise_connection_reset(None)
+                    self._raise_connection_reset('No data on read')
 
                 buffer.write(partial)
                 bytes_left -= len(partial)
@@ -211,10 +202,12 @@
             buffer.flush()
             return buffer.raw.getvalue()
 
-    def _raise_connection_reset(self, original):
-        """Disconnects the client and raises ConnectionResetError."""
+    def _raise_connection_reset(self, error):
+        description = error if isinstance(error, str) else str(error)
+        if isinstance(error, str):
+            error = Exception(error)
         self.close()  # Connection reset -> flag as socket closed
-        raise ConnectionResetError('The server has closed the connection.') from original
+        raise ConnectionResetError(description) from error
 
     # due to new https://github.com/python/cpython/pull/4386
     def sock_recv(self, n):
@@ -225,7 +218,7 @@
     def _sock_recv(self, fut, registered_fd, n):
         if registered_fd is not None:
             self._loop.remove_reader(registered_fd)
-        if fut.cancelled():
+        if fut.cancelled() or self._socket is None:
             return
         try:
@@ -249,7 +242,7 @@
     def _sock_sendall(self, fut, registered_fd, data):
         if registered_fd:
             self._loop.remove_writer(registered_fd)
-        if fut.cancelled():
+        if fut.cancelled() or self._socket is None:
            return
         try:
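
As an illustrative sketch (not part of this commit), the '# due to new https://github.com/python/cpython/pull/4386' comment above refers to sock_recv()/sock_sendall() being hand-rolled on top of loop.add_reader()/add_writer(). The minimal technique looks roughly like this; all names here are assumptions, and the socket is assumed connected and non-blocking.

import asyncio
import socket

def recv_via_add_reader(loop: asyncio.AbstractEventLoop, sock: socket.socket, n: int):
    fut = loop.create_future()

    def on_readable():
        loop.remove_reader(sock.fileno())
        if fut.cancelled():
            return
        try:
            fut.set_result(sock.recv(n))  # socket must be non-blocking
        except Exception as exc:
            fut.set_exception(exc)

    loop.add_reader(sock.fileno(), on_readable)
    return fut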

File: telethon/network/mtproto_sender.py

@@ -14,11 +14,11 @@ from ..errors import (
 from ..extensions import BinaryReader
 from ..tl import TLMessage, MessageContainer, GzipPacked
 from ..tl.all_tlobjects import tlobjects
-from ..tl.functions import InvokeAfterMsgRequest
 from ..tl.functions.auth import LogOutRequest
 from ..tl.types import (
     MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts,
-    MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo
+    MsgNewDetailedInfo, MsgDetailedInfo, MsgsStateReq, MsgResendReq,
+    MsgsAllInfo, MsgsStateInfo, RpcError
 )
 
 __log__ = logging.getLogger(__name__)
@@ -56,7 +56,8 @@
         # receiving other request from the main thread (e.g. an update arrives
         # and we need to process it) we must ensure that only one is calling
         # receive at a given moment, since the receive step is fragile.
-        self._recv_lock = asyncio.Lock()
+        self._read_lock = asyncio.Lock(loop=self._loop)
+        self._write_lock = asyncio.Lock(loop=self._loop)
 
         # Requests (as msg_id: Message) sent waiting to be received
         self._pending_receive = {}
@@ -73,11 +74,12 @@
         """
         return self.connection.is_connected()
 
-    def disconnect(self):
+    def disconnect(self, clear_pendings=True):
         """Disconnects from the server."""
         __log__.info('Disconnecting MtProtoSender...')
         self.connection.close()
-        self._clear_all_pending()
+        if clear_pendings:
+            self._clear_all_pending()
 
     # region Send and receive
@@ -90,6 +92,7 @@
         :param ordered: whether the requests should be invoked in the
                         order in which they appear or they can be executed
                         in arbitrary order in the server.
+        :return: a list of msg_ids which are correspond to sent requests.
         """
         if not utils.is_list_like(requests):
             requests = (requests,)
@@ -111,6 +114,7 @@
         messages = [TLMessage(self.session, r) for r in requests]
         self._pending_receive.update({m.msg_id: m for m in messages})
+        msg_ids = [m.msg_id for m in messages]
 
         __log__.debug('Sending requests with IDs: %s', ', '.join(
             '{}: {}'.format(m.request.__class__.__name__, m.msg_id)
@@ -128,12 +132,18 @@
                 m.container_msg_id = message.msg_id
 
             await self._send_message(message)
+        return msg_ids
+
+    def forget_pendings(self, msg_ids):
+        for msg_id in msg_ids:
+            if msg_id in self._pending_receive:
+                del self._pending_receive[msg_id]
 
     async def _send_acknowledge(self, msg_id):
         """Sends a message acknowledge for the given msg_id."""
         await self._send_message(TLMessage(self.session, MsgsAck([msg_id])))
 
-    async def receive(self, update_state):
+    async def receive(self, updates_handler):
         """
         Receives a single message from the connected endpoint.
@@ -144,21 +154,13 @@
         Any unhandled object (likely updates) will be passed to
         update_state.process(TLObject).
 
-        :param update_state:
-            the UpdateState that will process all the received
+        :param updates_handler:
+            the handler that will process all the received
             Update and Updates objects.
         """
-        if self._recv_lock.locked():
-            with await self._recv_lock:
-                # Don't busy wait, acquire it but return because there's
-                # already a receive running and we don't want another one.
-                # It would lock until Telegram sent another update even if
-                # the current receive already received the expected response.
-                return
-
+        await self._read_lock.acquire()
         try:
-            with await self._recv_lock:
-                body = await self.connection.recv()
+            body = await self.connection.recv()
         except (BufferError, InvalidChecksumError):
             # TODO BufferError, we should spot the cause...
             # "No more bytes left"; something wrong happened, clear
@@ -172,11 +174,12 @@
                       len(self._pending_receive))
             self._clear_all_pending()
             return
+        finally:
+            self._read_lock.release()
 
         message, remote_msg_id, remote_seq = self._decode_msg(body)
         with BinaryReader(message) as reader:
-            await self._process_msg(remote_msg_id, remote_seq, reader, update_state)
-            await self._send_acknowledge(remote_msg_id)
+            await self._process_msg(remote_msg_id, remote_seq, reader, updates_handler)
 
     # endregion
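
As an illustrative sketch (not part of this commit), receive() now takes a plain callable instead of an UpdateState object, which is part of what lets the garbage collector drop the client (see the commit list above). How a read loop might drive it, with every name except receive() and is_connected() being an assumption:

async def recv_loop(sender, handle_update):
    while sender.is_connected():
        # receive() acquires the read lock itself and passes any
        # unhandled object (an update) to the given callable
        await sender.receive(handle_update)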
@@ -188,7 +191,11 @@
 
         :param message: the TLMessage to be sent.
         """
-        await self.connection.send(helpers.pack_message(self.session, message))
+        await self._write_lock.acquire()
+        try:
+            await self.connection.send(helpers.pack_message(self.session, message))
+        finally:
+            self._write_lock.release()
 
     def _decode_msg(self, body):
         """
@@ -206,14 +213,14 @@
         with BinaryReader(body) as reader:
             return helpers.unpack_message(self.session, reader)
 
-    async def _process_msg(self, msg_id, sequence, reader, state):
+    async def _process_msg(self, msg_id, sequence, reader, updates_handler):
         """
         Processes the message read from the network inside reader.
 
         :param msg_id: the ID of the message.
         :param sequence: the sequence of the message.
         :param reader: the BinaryReader that contains the message.
-        :param state: the current UpdateState.
+        :param updates_handler: the handler to process Update and Updates objects.
         :return: true if the message was handled correctly, false otherwise.
         """
         # TODO Check salt, session_id and sequence_number
@@ -224,15 +231,16 @@
 
         # These are a bit of special case, not yet generated by the code gen
         if code == 0xf35c6d01:  # rpc_result, (response of an RPC call)
             __log__.debug('Processing Remote Procedure Call result')
+            await self._send_acknowledge(msg_id)
             return await self._handle_rpc_result(msg_id, sequence, reader)
 
         if code == MessageContainer.CONSTRUCTOR_ID:
             __log__.debug('Processing container result')
-            return await self._handle_container(msg_id, sequence, reader, state)
+            return await self._handle_container(msg_id, sequence, reader, updates_handler)
 
         if code == GzipPacked.CONSTRUCTOR_ID:
             __log__.debug('Processing gzipped result')
-            return await self._handle_gzip_packed(msg_id, sequence, reader, state)
+            return await self._handle_gzip_packed(msg_id, sequence, reader, updates_handler)
 
         if code not in tlobjects:
             __log__.warning(
@@ -250,6 +258,14 @@
         if isinstance(obj, BadServerSalt):
             return await self._handle_bad_server_salt(msg_id, sequence, obj)
 
+        if isinstance(obj, (MsgsStateReq, MsgResendReq)):
+            # just answer we don't know anything
+            return await self._handle_msgs_state_forgotten(msg_id, sequence, obj)
+
+        if isinstance(obj, MsgsAllInfo):
+            # not interesting now
+            return True
+
         if isinstance(obj, BadMsgNotification):
             return await self._handle_bad_msg_notification(msg_id, sequence, obj)
@@ -259,11 +275,8 @@
         if isinstance(obj, MsgNewDetailedInfo):
             return await self._handle_msg_new_detailed_info(msg_id, sequence, obj)
 
-        if isinstance(obj, NewSessionCreated):
-            return await self._handle_new_session_created(msg_id, sequence, obj)
-
         if isinstance(obj, MsgsAck):  # may handle the request we wanted
-            # Ignore every ack request *unless* when logging out, when it's
+            # Ignore every ack request *unless* when logging out,
             # when it seems to only make sense. We also need to set a non-None
             # result since Telegram doesn't send the response for these.
             for msg_id in obj.msg_ids:
@@ -284,8 +297,9 @@
 
         # If the object isn't any of the above, then it should be an Update.
         self.session.process_entities(obj)
-        if state:
-            state.process(obj)
+        await self._send_acknowledge(msg_id)
+        if updates_handler:
+            updates_handler(obj)
 
         return True
@@ -372,22 +386,24 @@
 
         return True
 
-    async def _handle_container(self, msg_id, sequence, reader, state):
+    async def _handle_container(self, msg_id, sequence, reader, updates_handler):
         """
         Handles a MessageContainer response.
 
         :param msg_id: the ID of the message.
         :param sequence: the sequence of the message.
         :param reader: the reader containing the MessageContainer.
+        :param updates_handler: handler to handle Update and Updates objects.
         :return: true, as it always succeeds.
         """
+        __log__.debug('Handling container')
         for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
             begin_position = reader.tell_position()
 
             # Note that this code is IMPORTANT for skipping RPC results of
             # lost requests (i.e., ones from the previous connection session)
             try:
-                if not await self._process_msg(inner_msg_id, sequence, reader, state):
+                if not await self._process_msg(inner_msg_id, sequence, reader, updates_handler):
                     reader.set_position(begin_position + inner_len)
             except:
                 # If any error is raised, something went wrong; skip the packet
@@ -406,7 +422,6 @@
         :return: true, as it always succeeds.
         """
         self.session.salt = bad_salt.new_server_salt
-        self.session.save()
 
         # "the bad_server_salt response is received with the
         # correct salt, and the message is to be re-sent with it"
@@ -414,6 +429,10 @@
 
         return True
 
+    async def _handle_msgs_state_forgotten(self, msg_id, sequence, req):
+        await self._send_message(TLMessage(self.session, MsgsStateInfo(msg_id, chr(1) * len(req.msg_ids))))
+        return True
+
     async def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg):
         """
         Handles a BadMessageError response.
@@ -476,19 +495,6 @@
         await self._send_acknowledge(msg_new.answer_msg_id)
         return True
 
-    async def _handle_new_session_created(self, msg_id, sequence, new_session):
-        """
-        Handles a NewSessionCreated response.
-
-        :param msg_id: the ID of the message.
-        :param sequence: the sequence of the message.
-        :param reader: the reader containing the NewSessionCreated.
-        :return: true, as it always succeeds.
-        """
-        self.session.salt = new_session.server_salt
-        # TODO https://goo.gl/LMyN7A
-        return True
-
     async def _handle_rpc_result(self, msg_id, sequence, reader):
         """
         Handles a RPCResult response.
@@ -507,7 +513,7 @@
         __log__.debug('Received response for request with ID %d', request_id)
         request = self._pop_request(request_id)
 
-        if inner_code == 0x2144ca19:  # RPC Error
+        if inner_code == RpcError.CONSTRUCTOR_ID:  # RPC Error
             reader.seek(4)
             if self.session.report_errors and request:
                 error = rpc_message_to_error(
@@ -530,6 +536,7 @@
             return True  # All contents were read okay
 
         elif request:
+            __log__.debug('Reading request response')
             if inner_code == GzipPacked.CONSTRUCTOR_ID:
                 with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
                     request.on_response(compressed_reader)
@@ -566,16 +573,18 @@
             )
             return False
 
-    async def _handle_gzip_packed(self, msg_id, sequence, reader, state):
+    async def _handle_gzip_packed(self, msg_id, sequence, reader, updates_handler):
         """
         Handles a GzipPacked response.
 
         :param msg_id: the ID of the message.
         :param sequence: the sequence of the message.
         :param reader: the reader containing the GzipPacked.
+        :param updates_handler: the handler to process Update and Updates objects.
         :return: the result of processing the packed message.
         """
+        __log__.debug('Handling gzip packed data')
         with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
-            return await self._process_msg(msg_id, sequence, compressed_reader, state)
+            return await self._process_msg(msg_id, sequence, compressed_reader, updates_handler)
 
     # endregion
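
As an illustrative sketch (not part of this commit), the new send() return value and forget_pendings() together let a caller clean up _pending_receive when an invocation is abandoned. The wrapper below is an assumption for illustration only:

async def invoke_once(sender, request):
    msg_ids = await sender.send(request)
    try:
        while not request.confirm_received.is_set():
            await sender.receive(None)
        return request.result
    except ConnectionResetError:
        sender.forget_pendings(msg_ids)  # drop the stale pending entries
        raise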

File: sessions (abstract Session base class)

@@ -67,6 +67,22 @@ class Session(ABC):
         """
         raise NotImplementedError
 
+    @property
+    @abstractmethod
+    def user_id(self):
+        """
+        Returns an ``user_id`` which the session related to.
+        """
+        raise NotImplementedError
+
+    @user_id.setter
+    @abstractmethod
+    def user_id(self, value):
+        """
+        Sets the ``user_id`` which the session related to.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def get_update_state(self, entity_id):
         """
@@ -94,7 +110,7 @@
         """
 
     @abstractmethod
-    def save(self):
+    async def save(self):
         """
         Called whenever important properties change. It should
         make persist the relevant session information to disk.
@@ -102,7 +118,7 @@
         raise NotImplementedError
 
     @abstractmethod
-    def delete(self):
+    async def delete(self):
         """
         Called upon client.log_out(). Should delete the stored
         information from disk since it's not valid anymore.
@@ -125,7 +141,7 @@
         raise NotImplementedError
 
     @abstractmethod
-    def get_input_entity(self, key):
+    async def get_input_entity(self, key):
         """
         Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``).
         The library uses this method whenever an ``InputPeer`` is needed
@@ -135,7 +151,7 @@
         raise NotImplementedError
 
     @abstractmethod
-    def cache_file(self, md5_digest, file_size, instance):
+    async def cache_file(self, md5_digest, file_size, instance):
         """
         Caches the given file information persistently, so that it
         doesn't need to be re-uploaded in case the file is used again.
@@ -146,7 +162,7 @@
         raise NotImplementedError
 
     @abstractmethod
-    def get_file(self, md5_digest, file_size, cls):
+    async def get_file(self, md5_digest, file_size, cls):
         """
         Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size``
         match an existing saved record. The class will either be an
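
As an illustrative sketch (not part of this commit), save(), delete(), get_input_entity(), cache_file() and get_file() are now coroutines, so concrete sessions must declare them with async def and callers must await them. A toy subclass showing only this persistence subset; the class name is an assumption and the remaining abstract members are omitted:

class DictSession(Session):
    def __init__(self):
        self._files = {}

    async def save(self):
        pass  # nothing to persist for an in-memory store

    async def delete(self):
        self._files.clear()

    async def cache_file(self, md5_digest, file_size, instance):
        self._files[(md5_digest, file_size)] = instance

    async def get_file(self, md5_digest, file_size, cls):
        return self._files.get((md5_digest, file_size))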

File: sessions (MemorySession)

@@ -32,6 +32,7 @@ class MemorySession(Session):
         self._server_address = None
         self._port = None
         self._auth_key = None
+        self._user_id = None
 
         self._files = {}
         self._entities = set()
@@ -58,6 +59,14 @@
     def auth_key(self, value):
         self._auth_key = value
 
+    @property
+    def user_id(self):
+        return self._user_id
+
+    @user_id.setter
+    def user_id(self, value):
+        self._user_id = value
+
     def get_update_state(self, entity_id):
         return self._update_states.get(entity_id, None)
@@ -67,10 +76,10 @@
     def close(self):
         pass
 
-    def save(self):
+    async def save(self):
         pass
 
-    def delete(self):
+    async def delete(self):
         pass
 
     def _entity_values_to_row(self, id, hash, username, phone, name):
@@ -170,7 +179,7 @@
         except StopIteration:
             pass
 
-    def get_input_entity(self, key):
+    async def get_input_entity(self, key):
         try:
             if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
                 # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
@@ -215,14 +224,14 @@
         else:
             raise ValueError('Could not find input entity with key ', key)
 
-    def cache_file(self, md5_digest, file_size, instance):
+    async def cache_file(self, md5_digest, file_size, instance):
         if not isinstance(instance, (InputDocument, InputPhoto)):
             raise TypeError('Cannot cache %s instance' % type(instance))
         key = (md5_digest, file_size, _SentFileType.from_type(instance))
         value = (instance.id, instance.access_hash)
         self._files[key] = value
 
-    def get_file(self, md5_digest, file_size, cls):
+    async def get_file(self, md5_digest, file_size, cls):
         key = (md5_digest, file_size, _SentFileType.from_type(cls))
         try:
             return cls(self._files[key])

File: sessions (SQLiteSession)

@@ -213,7 +213,7 @@ class SQLiteSession(MemorySession):
             ))
         c.close()
 
-    def get_update_state(self, entity_id):
+    async def get_update_state(self, entity_id):
         c = self._cursor()
         row = c.execute('select pts, qts, date, seq from update_state '
                         'where id = ?', (entity_id,)).fetchone()
@@ -223,7 +223,7 @@
             date = datetime.datetime.utcfromtimestamp(date)
             return types.updates.State(pts, qts, date, seq, unread_count=0)
 
-    def set_update_state(self, entity_id, state):
+    async def set_update_state(self, entity_id, state):
         c = self._cursor()
         c.execute('insert or replace into update_state values (?,?,?,?,?)',
                   (entity_id, state.pts, state.qts,
@@ -231,7 +231,7 @@
         c.close()
         self.save()
 
-    def save(self):
+    async def save(self):
         """Saves the current session object as session_user_id.session"""
         self._conn.commit()
@@ -248,7 +248,7 @@
             self._conn.close()
             self._conn = None
 
-    def delete(self):
+    async def delete(self):
         """Deletes the current session file"""
         if self.filename == ':memory:':
             return True
@@ -319,7 +319,7 @@
 
     # File processing
 
-    def get_file(self, md5_digest, file_size, cls):
+    async def get_file(self, md5_digest, file_size, cls):
         tuple_ = self._cursor().execute(
             'select id, hash from sent_files '
             'where md5_digest = ? and file_size = ? and type = ?',
@@ -329,7 +329,7 @@
             # Both allowed classes have (id, access_hash) as parameters
             return cls(tuple_[0], tuple_[1])
 
-    def cache_file(self, md5_digest, file_size, instance):
+    async def cache_file(self, md5_digest, file_size, instance):
         if not isinstance(instance, (InputDocument, InputPhoto)):
             raise TypeError('Cannot cache %s instance' % type(instance))

File: telethon/telegram_bare_client.py

@@ -1,8 +1,8 @@
 import asyncio
 import logging
 import os
-from asyncio import Lock
-from datetime import timedelta
+from asyncio import Lock, Event
+from datetime import timedelta, datetime
 import platform
 
 from . import version, utils
 from .crypto import rsa
@@ -13,7 +13,7 @@ from .errors import (
     RpcCallFailError
 )
 from .network import authenticator, MtProtoSender, ConnectionTcpFull
-from .sessions import Session, SQLiteSession
+from .sessions import Session
 from .tl import TLObject
 from .tl.all_tlobjects import LAYER
 from .tl.functions import (
@@ -25,10 +25,17 @@ from .tl.functions.auth import (
 from .tl.functions.help import (
     GetCdnConfigRequest, GetConfigRequest
 )
-from .tl.functions.updates import GetStateRequest
+from .tl.functions.updates import GetStateRequest, GetDifferenceRequest
+from .tl.types import (
+    Pong, PeerUser, PeerChat, Message, Updates, UpdateShort, UpdateNewChannelMessage, UpdateEditChannelMessage,
+    UpdateDeleteChannelMessages, UpdateChannelTooLong, UpdateNewMessage, NewSessionCreated, UpdatesTooLong,
+    UpdateShortSentMessage, MessageEmpty, UpdateShortMessage, UpdateShortChatMessage, UpdatesCombined
+)
 from .tl.types.auth import ExportedAuthorization
-from .update_state import UpdateState
+from .tl.types.messages import AffectedMessages, AffectedHistory
+from .tl.types.updates import DifferenceEmpty, DifferenceTooLong, DifferenceSlice
+
+MAX_TIMEOUT = 15  # in seconds
 
 DEFAULT_DC_ID = 4
 DEFAULT_IPV4_IP = '149.154.167.51'
 DEFAULT_IPV6_IP = '[2001:67c:4e8:f002::a]'
@@ -71,8 +78,11 @@ class TelegramBareClient:
                  use_ipv6=False,
                  proxy=None,
                  timeout=timedelta(seconds=5),
+                 ping_delay=timedelta(minutes=1),
+                 update_handler=None,
+                 unauthorized_handler=None,
                  loop=None,
-                 report_errors=True,
+                 report_errors=None,
                  device_model=None,
                  system_version=None,
                  app_version=None,
@@ -87,12 +97,8 @@
         self._use_ipv6 = use_ipv6
 
         # Determine what session object we have
-        if isinstance(session, str) or session is None:
-            session = SQLiteSession(session)
-        elif not isinstance(session, Session):
-            raise TypeError(
-                'The given session must be a str or a Session instance.'
-            )
+        if not isinstance(session, Session):
+            raise TypeError('The given session must be a Session instance.')
 
         self._loop = loop if loop else asyncio.get_event_loop()
@@ -105,7 +111,8 @@
             DEFAULT_PORT
         )
 
-        session.report_errors = report_errors
+        if report_errors is not None:
+            session.report_errors = report_errors
         self.session = session
         self.api_id = int(api_id)
         self.api_hash = api_hash
@@ -128,10 +135,6 @@
         # expensive operation.
         self._exported_sessions = {}
 
-        # This member will process updates if enabled.
-        # One may change self.updates.enabled at any later point.
-        self.updates = UpdateState(self._loop)
-
         # Used on connection - the user may modify these and reconnect
         system = platform.uname()
         self.device_model = device_model or system.system or 'Unknown'
@@ -140,6 +143,12 @@
         self.lang_code = lang_code
         self.system_lang_code = system_lang_code
 
+        self._state = None
+        self._sync_loading = False
+        self.update_handler = update_handler
+        self.unauthorized_handler = unauthorized_handler
+        self._last_update = datetime.now()
+
         # Despite the state of the real connection, keep track of whether
         # the user has explicitly called .connect() or .disconnect() here.
         # This information is required by the read thread, who will be the
@@ -147,90 +156,72 @@
         # doesn't explicitly call .disconnect(), thus telling it to stop
         # retrying. The main thread, knowing there is a background thread
         # attempting reconnection as soon as it happens, will just sleep.
-        self._user_connected = False
-
-        # Save whether the user is authorized here (a.k.a. logged in)
-        self._authorized = None  # None = We don't know yet
-
-        # The first request must be in invokeWithLayer(initConnection(X)).
-        # See https://core.telegram.org/api/invoking#saving-client-info.
-        self._first_request = True
+        self._user_connected = Event(loop=self._loop)
+        self._authorized = False
+        self._shutdown = False
 
         self._recv_loop = None
         self._ping_loop = None
-        self._state_loop = None
-
-        # Default PingRequest delay
-        self._ping_delay = timedelta(minutes=1)
-
-        # Also have another delay for GetStateRequest.
-        #
-        # If the connection is kept alive for long without invoking any
-        # high level request the server simply stops sending updates.
-        # TODO maybe we can have ._last_request instead if any req works?
-        self._state_delay = timedelta(hours=1)
-
-    # endregion
-
-    # region Connecting
-
-    async def connect(self, _sync_updates=True):
-        """Connects to the Telegram servers, executing authentication if
-           required. Note that authenticating to the Telegram servers is
-           not the same as authenticating the desired user itself, which
-           may require a call (or several) to 'sign_in' for the first time.
-
-           Note that the optional parameters are meant for internal use.
-
-           If '_sync_updates', sync_updates() will be called and a
-           second thread will be started if necessary. Note that this
-           will FAIL if the client is not connected to the user's
-           native data center, raising a "UserMigrateError", and
-           calling .disconnect() in the process.
-        """
-        __log__.info('Connecting to %s:%d...',
-                     self.session.server_address, self.session.port)
-
+        self._reconnection_loop = None
+        self._idling = asyncio.Event()
+
+        if isinstance(ping_delay, timedelta):
+            self._ping_delay = ping_delay.seconds
+        elif isinstance(ping_delay, (int, float)):
+            self._ping_delay = float(ping_delay)
+        else:
+            raise TypeError('Invalid timeout type', type(timeout))
+
+    def __del__(self):
+        self.disconnect()
+
+    async def connect(self):
         try:
-            await self._sender.connect()
-            __log__.info('Connection success!')
-
-            # Connection was successful! Try syncing the update state
-            # UNLESS '_sync_updates' is False (we probably are in
-            # another data center and this would raise UserMigrateError)
-            # to also assert whether the user is logged in or not.
-            self._user_connected = True
-            if self._authorized is None and _sync_updates:
-                try:
-                    await self.sync_updates()
-                    await self._set_connected_and_authorized()
-                except UnauthorizedError:
-                    self._authorized = False
-            elif self._authorized:
-                await self._set_connected_and_authorized()
-
+            if not self._sender.is_connected():
+                await self._sender.connect()
+
+            if not self.session.auth_key:
+                try:
+                    self.session.auth_key, self.session.time_offset = \
+                        await authenticator.do_authentication(self._sender.connection)
+                    await self.session.save()
+                except BrokenAuthKeyError:
+                    self._user_connected.clear()
+                    return False
+
+            if TelegramBareClient._config is None:
+                TelegramBareClient._config = await self(self._wrap_init_connection(GetConfigRequest()))
+
+            if not self._authorized:
+                try:
+                    self._state = await self(self._wrap_init_connection(GetStateRequest()))
+                    self._authorized = True
+                except UnauthorizedError:
+                    self._authorized = False
+
+            self.run_loops()
+            self._user_connected.set()
             return True
 
         except TypeNotFoundError as e:
             # This is fine, probably layer migration
             __log__.warning('Connection failed, got unexpected type with ID '
                             '%s. Migrating?', hex(e.invalid_constructor_id))
-            self.disconnect()
-            return await self.connect(_sync_updates=_sync_updates)
+            self.disconnect(False)
+            return await self.connect()
 
         except AuthKeyError as e:
             # As of late March 2018 there were two AUTH_KEY_DUPLICATED
             # reports. Retrying with a clean auth_key should fix this.
-            __log__.warning('Auth key error %s. Clearing it and retrying.', e)
-            self.disconnect()
-            self.session.auth_key = None
-            self.session.save()
-            return self.connect(_sync_updates=_sync_updates)
+            if not self._authorized:
+                __log__.warning('Auth key error %s. Clearing it and retrying.', e)
+                self.disconnect(False)
+                self.session.auth_key = None
+                return self.connect()
+            else:
+                raise
 
         except (RPCError, ConnectionError) as e:
             # Probably errors from the previous session, ignore them
             __log__.error('Connection failed due to %s', e)
-            self.disconnect()
+            self.disconnect(False)
             return False
 
     def is_connected(self):
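
As an illustrative sketch (not part of this commit), _user_connected changed from a bool to an asyncio.Event, so background loops can block on the connected state instead of polling it. The helper name below is an assumption:

import asyncio

async def wait_until_connected(connected: asyncio.Event):
    # returns at once if connect() already set the event,
    # otherwise suspends until it is set
    await connected.wait()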
@@ -249,24 +240,11 @@
             query=query
         ))
 
-    def disconnect(self):
+    def disconnect(self, shutdown=True):
         """Disconnects from the Telegram server"""
-        __log__.info('Disconnecting...')
-        self._user_connected = False
-        self._sender.disconnect()
-        if self._recv_loop:
-            self._recv_loop.cancel()
-            self._recv_loop = None
-        if self._ping_loop:
-            self._ping_loop.cancel()
-            self._ping_loop = None
-        if self._state_loop:
-            self._state_loop.cancel()
-            self._state_loop = None
-
-        # TODO Shall we clear the _exported_sessions, or may be reused?
-        self._first_request = True  # On reconnect it will be first again
-        self.session.set_update_state(0, self.updates.get_update_state(0))
-        self.session.close()
+        self._shutdown = shutdown
+        self._user_connected.clear()
+        self._sender.disconnect(clear_pendings=shutdown)
 
     async def _reconnect(self, new_dc=None):
         """If 'new_dc' is not set, only a call to .connect() will be made
@@ -277,32 +255,23 @@
         current data center, clears the auth key for the old DC, and
         connects to the new data center.
         """
-        if new_dc is None:
-            # Assume we are disconnected due to some error, so connect again
-            try:
-                if self.is_connected():
-                    __log__.info('Reconnection aborted: already connected')
-                    return True
-
-                __log__.info('Attempting reconnection...')
-                return await self.connect()
-            except ConnectionResetError as e:
-                __log__.warning('Reconnection failed due to %s', e)
-                return False
-        else:
-            # Since we're reconnecting possibly due to a UserMigrateError,
-            # we need to first know the Data Centers we can connect to. Do
-            # that before disconnecting.
-            dc = await self._get_dc(new_dc)
-            __log__.info('Reconnecting to new data center %s', dc)
-
-            self.session.set_dc(dc.id, dc.ip_address, dc.port)
-            # auth_key's are associated with a server, which has now changed
-            # so it's not valid anymore. Set to None to force recreating it.
-            self.session.auth_key = None
-            self.session.save()
-            self.disconnect()
+        await self._reconnect_lock.acquire()
+        try:
+            # Another thread may have connected again, so check that first
+            if self.is_connected() and new_dc is None:
+                return True
+
+            if new_dc is not None:
+                dc = await self._get_dc(new_dc)
+                self.disconnect(False)
+                self.session.set_dc(dc.id, dc.ip_address, dc.port)
+                await self.session.save()
 
             return await self.connect()
+        except (ConnectionResetError, TimeoutError):
+            return False
+        finally:
+            self._reconnect_lock.release()
 
     def set_proxy(self, proxy):
         """Change the proxy used by the connections.
@@ -348,7 +317,7 @@
         """
         # Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt
         # for clearly showing how to export the authorization! ^^
-        session = self._exported_sessions.get(dc_id)
+        session = self._exported_sessions.get(dc_id, None)
         if session:
             export_auth = None  # Already bound with the auth key
         else:
@@ -377,7 +346,7 @@
             timeout=self._sender.connection.get_timeout(),
             loop=self._loop
         )
-        await client.connect(_sync_updates=False)
+        await client.connect()
         if isinstance(export_auth, ExportedAuthorization):
             await client(ImportAuthorizationRequest(
                 id=export_auth.id, bytes=export_auth.bytes
@@ -385,12 +354,11 @@
         elif export_auth is not None:
             __log__.warning('Unknown export auth type %s', export_auth)
 
-        client._authorized = True  # We exported the auth, so we got auth
         return client
 
     async def _get_cdn_client(self, cdn_redirect):
         """Similar to ._get_exported_client, but for CDNs"""
-        session = self._exported_sessions.get(cdn_redirect.dc_id)
+        session = self._exported_sessions.get(cdn_redirect.dc_id, None)
         if not session:
             dc = await self._get_dc(cdn_redirect.dc_id, cdn=True)
             session = self.session.clone()
@@ -410,8 +378,7 @@
         # We won't be calling GetConfigRequest because it's only called
         # when needed by ._get_dc, and also it's static so it's likely
         # set already. Avoid invoking non-CDN methods by not syncing updates.
-        await client.connect(_sync_updates=False)
-        client._authorized = self._authorized
+        await client.connect()
         return client
 
     # endregion
@@ -461,100 +428,70 @@
             which = '{} requests ({})'.format(
                 len(request), [type(x).__name__ for x in request])
 
+        is_ping = any(isinstance(x, PingRequest) for x in request)
+        msg_ids = []
+
         __log__.debug('Invoking %s', which)
-        call_receive = \
-            not self._idling.is_set() or self._reconnect_lock.locked()
 
-        for retry in range(retries):
-            result = await self._invoke(call_receive, retry, request,
-                                        ordered=ordered)
-            if result is not None:
-                return result[0] if single else result
-
-            log = __log__.info if retry == 0 else __log__.warning
-            log('Invoking %s failed %d times, connecting again and retrying',
-                which, retry + 1)
-
-            await asyncio.sleep(1)
-            if not self._reconnect_lock.locked():
-                with await self._reconnect_lock:
-                    await self._reconnect()
-
-        raise RuntimeError('Number of retries reached 0 for {}.'.format(
-            which
-        ))
+        try:
+            for retry in range(retries):
+                result = None
+                for sub_retry in range(retries):
+                    msg_ids, result = await self._invoke(retry, request, ordered, msg_ids)
+                    if msg_ids:
+                        break
+                    if not self.is_connected():
+                        break
+                    __log__.error('Subretry %d is failed' % sub_retry)
+
+                if result is None:
+                    if not is_ping:
+                        try:
+                            pong = await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)), retries=1)
+                            if isinstance(pong, Pong):
+                                __log__.error('Connection is live, but no answer on %d retry' % retry)
+                                continue
+                        except RuntimeError:
+                            pass  # continue to reconnect
+
+                    if self.is_connected() and (retry + 1) % 2 == 0:
+                        __log__.error('Force disconnect on %d retry' % retry)
+                        self.disconnect(False)
+
+                    self._sender.forget_pendings(msg_ids)
+                    msg_ids = []
+
+                    if not self.is_connected():
+                        __log__.error('Pause before new retry on %d retry' % retry)
+                        await asyncio.sleep(retry + 1, loop=self._loop)
+                else:
+                    return result[0] if single else result
+        finally:
+            self._sender.forget_pendings(msg_ids)
+
+        raise RuntimeError('Number of retries is exceeded for {}.'.format(which))
 
     # Let people use client.invoke(SomeRequest()) instead client(...)
     invoke = __call__
-    async def _invoke(self, call_receive, retry, requests, ordered=False):
+    async def _invoke(self, retry, requests, ordered, msg_ids):
         try:
-            # Ensure that we start with no previous errors (i.e. resending)
-            for x in requests:
-                x.rpc_error = None
-
-            if not self.session.auth_key:
-                __log__.info('Need to generate new auth key before invoking')
-                self._first_request = True
-                self.session.auth_key, self.session.time_offset = \
-                    await authenticator.do_authentication(self._sender.connection)
-
-            if self._first_request:
-                __log__.info('Initializing a new connection while invoking')
-                if len(requests) == 1:
-                    requests = [self._wrap_init_connection(requests[0])]
-                else:
-                    # We need a SINGLE request (like GetConfig) to init conn.
-                    # Once that's done, the N original requests will be
-                    # invoked.
-                    TelegramBareClient._config = await self(
-                        self._wrap_init_connection(GetConfigRequest())
-                    )
-
-            await self._sender.send(requests, ordered=ordered)
-
-            if not call_receive:
-                await asyncio.wait(
-                    list(map(lambda x: x.confirm_received.wait(), requests)),
-                    timeout=self._sender.connection.get_timeout(),
-                    loop=self._loop
-                )
+            if not msg_ids:
+                msg_ids = await self._sender.send(requests, ordered)
+
+            # Ensure that we start with no previous errors (i.e. resending)
+            for x in requests:
+                x.rpc_error = None
+
+            if self._user_connected.is_set():
+                fut = asyncio.gather(*list(map(lambda x: x.confirm_received.wait(), requests)), loop=self._loop)
+                self._loop.call_later(self._sender.connection.get_timeout(), fut.cancel)
+                await fut
             else:
                 while not all(x.confirm_received.is_set() for x in requests):
-                    await self._sender.receive(update_state=self.updates)
-
-        except BrokenAuthKeyError:
-            __log__.error('Authorization key seems broken and was invalid!')
-            self.session.auth_key = None
-
-        except TypeNotFoundError as e:
-            # Only occurs when we call receive. May happen when
-            # we need to reconnect to another DC on login and
-            # Telegram somehow sends old objects (like configOld)
-            self._first_request = True
-            __log__.warning('Read unknown TLObject code ({}). '
-                            'Setting again first_request flag.'
-                            .format(hex(e.invalid_constructor_id)))
-
-        except TimeoutError:
-            __log__.warning('Invoking timed out')  # We will just retry
+                    await self._sender.receive(self._updates_handler)
+
+        except (TimeoutError, asyncio.CancelledError):
+            __log__.error('Timeout on %d retry' % retry)
 
         except ConnectionResetError as e:
-            __log__.warning('Connection was reset while invoking')
-            if self._user_connected:
-                # Server disconnected us, __call__ will try reconnecting.
-                try:
-                    self._sender.disconnect()
-                except:
-                    pass
-                return None
-            else:
-                # User never called .connect(), so raise this error.
-                raise RuntimeError('Tried to invoke without .connect()') from e
-
-        # Clear the flag if we got this far
-        self._first_request = False
+            if self._shutdown:
+                raise
+            __log__.error('Connection reset on %d retry: %r' % (retry, e))
 
         try:
             raise next(x.rpc_error for x in requests if x.rpc_error)
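
As an illustrative sketch (not part of this commit), the new __call__/_invoke pair amounts to a retry policy of resend only if the request was never registered, ping to see whether the link is alive, and force a disconnect every second failed retry. Reduced to its control flow; every name here is a placeholder:

async def invoke_with_retries(invoke_once, probe_alive, reconnect, retries=5):
    for retry in range(retries):
        result = await invoke_once()
        if result is not None:
            return result
        if await probe_alive():
            continue           # transport is fine, just ask again
        await reconnect()      # otherwise rebuild the connection first
    raise RuntimeError('Number of retries is exceeded')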
@@ -562,15 +499,32 @@
             if any(x.result is None for x in requests):
                 # "A container may only be accepted or
                 # rejected by the other party as a whole."
-                return None
+                return msg_ids, None
 
-            return [x.result for x in requests]
+            for req in requests:
+                if isinstance(req.result, TLObject) and req.result.SUBCLASS_OF_ID == Updates.SUBCLASS_OF_ID:
+                    self._updates_handler(req.result, False, False)
+                if isinstance(req.result, (AffectedMessages, AffectedHistory)):  # due to affect to pts
+                    self._updates_handler(UpdateShort(req.result, None), False, False)
 
-        except (PhoneMigrateError, NetworkMigrateError,
-                UserMigrateError) as e:
+            return msg_ids, [x.result for x in requests]
+
+        except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e:
+            if isinstance(e, (PhoneMigrateError, NetworkMigrateError)):
+                if self._authorized:
+                    raise
+                else:
+                    self.session.auth_key = None  # Force creating new auth_key
 
             __log__.error(
                 'DC error when invoking request, '
                 'attempting to reconnect at DC {}'.format(e.new_dc)
             )
 
             await self._reconnect(new_dc=e.new_dc)
-            return await self._invoke(call_receive, retry, requests)
+            self._sender.forget_pendings(msg_ids)
+            msg_ids = []
+            return msg_ids, None
 
         except (ServerError, RpcCallFailError) as e:
             # Telegram is having some issues, just retry
@@ -582,7 +536,16 @@
                 raise
 
             await asyncio.sleep(e.seconds, loop=self._loop)
-            return None
+            return msg_ids, None
+
+        except UnauthorizedError:
+            if self._authorized:
+                __log__.error('Authorization has lost')
+                self._authorized = False
+                self.disconnect()
+                if self.unauthorized_handler:
+                    await self.unauthorized_handler(self)
+            raise
 
     # Some really basic functionality
@@ -602,74 +565,216 @@ class TelegramBareClient:
# region Updates handling

async def sync_updates(self):
"""Synchronizes self.updates to their initial state. Will be
called automatically on connection if self.updates.enabled = True,
otherwise it should be called manually after enabling updates.
"""
self.updates.process(await self(GetStateRequest()))

async def _handle_update(self, update, seq_start, seq, date, do_get_diff, do_handlers, users=(), chats=()):
if isinstance(update, (UpdateNewChannelMessage, UpdateEditChannelMessage,
UpdateDeleteChannelMessages, UpdateChannelTooLong)):
# TODO: channel updates have their own pts sequences, so requires individual pts'es
return  # ignore channel updates to keep pts in the main _state in the correct state
if hasattr(update, 'pts'):
new_pts = self._state.pts + getattr(update, 'pts_count', 0)
if new_pts < update.pts:
__log__.debug('Have got a hole between pts => waiting 0.5 sec')
await asyncio.sleep(0.5, loop=self._loop)
if new_pts < update.pts:
if do_get_diff and not self._sync_loading:
__log__.debug('The hole between pts has not disappeared => going to get differences')
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
if update.pts > self._state.pts:
self._state.pts = update.pts
elif getattr(update, 'pts_count', 0) > 0:
__log__.debug('Have got the duplicate update (basing on pts) => ignoring')
return
elif hasattr(update, 'qts'):
if self._state.qts + 1 < update.qts:
__log__.debug('Have got a hole between qts => waiting 0.5 sec')
await asyncio.sleep(0.5, loop=self._loop)
if self._state.qts + 1 < update.qts:
if do_get_diff and not self._sync_loading:
__log__.debug('The hole between qts has not disappeared => going to get differences')
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
if update.qts > self._state.qts:
self._state.qts = update.qts
else:
__log__.debug('Have got the duplicate update (basing on qts) => ignoring')
return
elif seq > 0:
if seq_start > self._state.seq + 1:
__log__.debug('Have got a hole between seq => waiting 0.5 sec')
await asyncio.sleep(0.5, loop=self._loop)
if seq_start > self._state.seq + 1:
if do_get_diff and not self._sync_loading:
__log__.debug('The hole between seq has not disappeared => going to get differences')
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
self._state.seq = seq
self._state.date = max(self._state.date, date)
# endregion
# Constant read

if do_handlers and self.update_handler:
asyncio.ensure_future(self.update_handler(self, update, users, chats), loop=self._loop)
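The pts/qts/seq checks above follow the usual update-sequence rule: an update applies cleanly when local_pts + pts_count reaches exactly update.pts; anything further ahead is a gap (wait briefly, then fetch the difference) and anything behind is a duplicate. A compact, standalone approximation of that decision:

    def classify_update(local_pts, update_pts, pts_count=0):
        """Return 'apply', 'gap' or 'duplicate' for a pts-bearing update."""
        expected = local_pts + pts_count
        if expected == update_pts:
            return 'apply'        # the update directly follows our state
        if expected < update_pts:
            return 'gap'          # something was missed: time for getDifference
        return 'duplicate'        # already accounted for: ignore it

    assert classify_update(100, 101, pts_count=1) == 'apply'
    assert classify_update(100, 105, pts_count=1) == 'gap'
    assert classify_update(100, 99, pts_count=1) == 'duplicate'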
async def _get_difference(self):
self._sync_loading = True
try:
difference = await self(GetDifferenceRequest(self._state.pts, self._state.date, self._state.qts))
if isinstance(difference, DifferenceEmpty):
__log__.debug('Have got DifferenceEmpty => just update seq and date')
self._state.seq = difference.seq
self._state.date = difference.date
return
if isinstance(difference, DifferenceTooLong):
__log__.debug('Have got DifferenceTooLong => update pts and try again')
self._state.pts = difference.pts
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return
__log__.debug('Preparing updates from differences')
self._state = difference.intermediate_state \
if isinstance(difference, DifferenceSlice) else difference.state
messages = [UpdateNewMessage(msg, self._state.pts, 0) for msg in difference.new_messages]
self._updates_handler(
Updates(messages + difference.other_updates,
difference.users, difference.chats, self._state.date, self._state.seq),
False
)
if isinstance(difference, DifferenceSlice):
asyncio.ensure_future(self._get_difference(), loop=self._loop)
except ConnectionResetError: # it happens on unauth due to _get_difference is often on the background
pass
except Exception as e:
__log__.exception('Exception on _get_difference: %r', e)
finally:
self._sync_loading = False
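_get_difference asks the server for everything missed since (pts, date, qts): DifferenceEmpty only advances seq and date, DifferenceTooLong resets pts and retries, and DifferenceSlice schedules another round until the backlog is drained. A rough shape of that loop, with hypothetical result classes standing in for the TL types:

    from dataclasses import dataclass

    @dataclass
    class DifferenceEmpty:
        seq: int
        date: object

    @dataclass
    class Difference:
        new_messages: list
        is_slice: bool = False  # True models DifferenceSlice: more data awaits

    async def fetch_all_differences(get_difference, apply):
        """Keep calling get_difference() until the server says we are caught up."""
        while True:
            diff = await get_difference()
            if isinstance(diff, DifferenceEmpty):
                return                    # nothing new, only seq/date advanced
            apply(diff.new_messages)      # hand recovered messages to the handlers
            if not diff.is_slice:
                return                    # a full Difference ends the loop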
# This is async so that the overrided version in TelegramClient can be
# async without problems.
async def _set_connected_and_authorized(self):
self._authorized = True

# TODO: Some of logic was moved from MtProtoSender and probably must be moved back.
def _updates_handler(self, updates, do_get_diff=True, do_handlers=True):
if do_get_diff:
self._last_update = datetime.now()
if isinstance(updates, NewSessionCreated):
self.session.salt = updates.server_salt
if self._state is None:
return False # not ready yet
if self._sync_loading and do_get_diff:
return False # ignore all if in sync except from difference (do_get_diff = False)
if isinstance(updates, (NewSessionCreated, UpdatesTooLong)):
if do_get_diff: # to prevent possible loops
__log__.debug('Have got %s => going to get differences', type(updates))
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return False
seq = getattr(updates, 'seq', 0)
seq_start = getattr(updates, 'seq_start', seq)
date = getattr(updates, 'date', self._state.date)
if isinstance(updates, UpdateShort):
asyncio.ensure_future(
self._handle_update(updates.update, seq_start, seq, date, do_get_diff, do_handlers),
loop=self._loop
)
return True
if isinstance(updates, UpdateShortSentMessage):
asyncio.ensure_future(self._handle_update(
UpdateNewMessage(MessageEmpty(updates.id), updates.pts, updates.pts_count),
seq_start, seq, date, do_get_diff, do_handlers
), loop=self._loop)
return True
if isinstance(updates, (UpdateShortMessage, UpdateShortChatMessage)):
from_id = getattr(updates, 'from_id', self.session.user_id)
to_id = updates.user_id if isinstance(updates, UpdateShortMessage) else updates.chat_id
if not updates.out:
from_id, to_id = to_id, from_id
to_id = PeerUser(to_id) if isinstance(updates, UpdateShortMessage) else PeerChat(to_id)
message = Message(
id=updates.id, to_id=to_id, date=updates.date, message=updates.message, out=updates.out,
mentioned=updates.mentioned, media_unread=updates.media_unread, silent=updates.silent,
from_id=from_id, fwd_from=updates.fwd_from, via_bot_id=updates.via_bot_id,
reply_to_msg_id=updates.reply_to_msg_id, entities=updates.entities
)
asyncio.ensure_future(self._handle_update(
UpdateNewMessage(message, updates.pts, updates.pts_count),
seq_start, seq, date, do_get_diff, do_handlers
), loop=self._loop)
return True
if isinstance(updates, (Updates, UpdatesCombined)):
for upd in updates.updates:
asyncio.ensure_future(
self._handle_update(upd, seq_start, seq, date, do_get_diff, do_handlers,
updates.users, updates.chats),
loop=self._loop
)
return True
if do_get_diff: # to prevent possible loops
__log__.debug('Have got unsupported type of updates: %s => going to get differences', type(updates))
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
return False
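Everything that is not already a bare Update (UpdateShort, UpdateShortSentMessage, UpdateShortMessage/UpdateShortChatMessage, Updates, UpdatesCombined) is normalised above and then scheduled on the loop. From the application side the contract is simply a coroutine with the (client, update, users, chats) signature used in _handle_update; an illustrative handler:

    import logging

    log = logging.getLogger('example')

    async def my_update_handler(client, update, users=(), chats=()):
        # Scheduled by the library once per normalised update; users/chats carry
        # the entities that arrived in the same container, when there were any.
        log.info('got %s (%d users, %d chats cached)',
                 type(update).__name__, len(users), len(chats))

    # Illustrative wiring only:
    # client.update_handler = my_update_handler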
def run_loops(self):
if self._recv_loop is None:
self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop)
if self._ping_loop is None:
self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop)
if self._state_loop is None:
self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop)
async def _ping_loop_impl(self):
while self._user_connected:
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay.seconds, loop=self._loop)

while True:
if self._shutdown:
break
try:
await self._user_connected.wait()
await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)))
await asyncio.sleep(self._ping_delay, loop=self._loop)
except RuntimeError:
pass # Can be not happy due to connection problems
except asyncio.CancelledError:
break
except:
self._ping_loop = None
raise
self._ping_loop = None
async def _state_loop_impl(self):
while self._user_connected:
await asyncio.sleep(self._state_delay.seconds, loop=self._loop)
await self._sender.send(GetStateRequest())
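Two small background loops keep the link healthy: _ping_loop_impl pings with a random id every _ping_delay seconds, and _state_loop_impl periodically sends GetStateRequest so the server keeps streaming updates. A generic asyncio sketch of that periodic keep-alive pattern; the two callables and the delays are placeholders:

    import asyncio
    import os

    async def keep_alive(send_ping, send_get_state, ping_delay=60, state_delay=300):
        """Single-task sketch of both keep-alive duties (the client uses two tasks)."""
        since_state = 0
        while True:
            await send_ping(int.from_bytes(os.urandom(8), 'big', signed=True))
            await asyncio.sleep(ping_delay)
            since_state += ping_delay
            if since_state >= state_delay:
                await send_get_state()  # nudges the server to resume update delivery
                since_state = 0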
async def _recv_loop_impl(self):
__log__.info('Starting to wait for items from the network')
self._idling.set()
need_reconnect = False
while self._user_connected:
try:
if need_reconnect:
__log__.info('Attempting reconnection from read loop')
need_reconnect = False
with await self._reconnect_lock:
while self._user_connected and not await self._reconnect():
# Retry forever, this is instant messaging
await asyncio.sleep(0.1, loop=self._loop)
else:
# Telegram seems to kick us every 1024 items received
# from the network not considering things like bad salt.
# We must execute some *high level* request (that's not
# a ping) if we want to receive updates again.
# TODO Test if getDifference works too (better alternative)
await self._sender.send(GetStateRequest())
__log__.debug('Receiving items from the network...')
await self._sender.receive(update_state=self.updates)
except TimeoutError:
# No problem.
__log__.debug('Receiving items from the network timed out')
except ConnectionError:
need_reconnect = True
__log__.error('Connection was reset while receiving items')
await asyncio.sleep(1, loop=self._loop)
except:
self._idling.clear()
raise
# Unknown exception, pass it to the main thread
self._idling.clear()
__log__.info('Connection closed by the user, not reading anymore')

async def _recv_loop_impl(self):
timeout = 1
while True:
if self._shutdown:
break
try:
if self._user_connected.is_set():
if self._authorized and datetime.now() - self._last_update > timedelta(minutes=15):
__log__.debug('No updates for 15 minutes => going to get differences')
self._last_update = datetime.now()
self._sync_loading = True
asyncio.ensure_future(self._get_difference(), loop=self._loop)
await self._sender.receive(self._updates_handler)
else:
if await self._reconnect():
__log__.info('Connection has established')
timeout = 1
else:
await asyncio.sleep(timeout, loop=self._loop)
timeout = min(timeout * 2, MAX_TIMEOUT)
except TimeoutError:
# No problem.
pass
except ConnectionResetError as error:
__log__.info('Connection reset error in recv loop: %r' % error)
self._user_connected.clear()
except asyncio.CancelledError:
self.disconnect()
break
except Exception as error:
__log__.exception('[ERROR: %r] on the read loop, please report', error)
self._recv_loop = None
if self._shutdown and self._ping_loop:
self._ping_loop.cancel()

# endregion
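The rewritten read loop also changes the reconnect policy: instead of retrying every 0.1 s forever, it doubles the sleep after each failed attempt up to MAX_TIMEOUT and resets it to one second once a connection is established. A minimal standalone sketch of that backoff:

    import asyncio

    MAX_TIMEOUT = 15  # seconds, mirroring the constant used above

    async def reconnect_with_backoff(reconnect):
        """Call reconnect() until it reports success, backing off exponentially."""
        timeout = 1
        while True:
            if await reconnect():
                return
            await asyncio.sleep(timeout)
            timeout = min(timeout * 2, MAX_TIMEOUT)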

View File

@@ -510,7 +510,7 @@ class TelegramClient(TelegramBareClient):
return False
self.disconnect()
self.session.delete()
await self.session.delete()
self._authorized = False
return True
@@ -1805,7 +1805,7 @@ class TelegramClient(TelegramBareClient):
to_cache = utils.get_input_photo(msg.media.photo)
else:
to_cache = utils.get_input_document(msg.media.document)
self.session.cache_file(md5, size, to_cache)
await self.session.cache_file(md5, size, to_cache)
return msg
@@ -1849,7 +1849,7 @@ class TelegramClient(TelegramBareClient):
input_photo = utils.get_input_photo((await self(UploadMediaRequest(
entity, media=InputMediaUploadedPhoto(fh)
))).photo)
self.session.cache_file(fh.md5, fh.size, input_photo)
await self.session.cache_file(fh.md5, fh.size, input_photo)
fh = input_photo
if captions:
@@ -1957,7 +1957,7 @@ class TelegramClient(TelegramBareClient):
file = stream.read()
hash_md5.update(file)
if use_cache:
cached = self.session.get_file(
cached = await self.session.get_file(
hash_md5.digest(), file_size, cls=use_cache
)
if cached:
@@ -2122,7 +2122,7 @@ class TelegramClient(TelegramBareClient):
media, file, date, progress_callback
)
elif isinstance(media, MessageMediaContact):
return await self._download_contact(
return self._download_contact(
media, file
)
@@ -2472,7 +2472,7 @@ class TelegramClient(TelegramBareClient):
be passed instead.
"""
self.updates.handler = self._on_handler
self.update_handler = self._on_handler
if isinstance(event, type):
event = event()
elif not event:
@@ -2555,7 +2555,7 @@ class TelegramClient(TelegramBareClient):
# infinite loop here (so check against old pts to stop)
break
self.updates.process(Updates(
self._updates_handler(Updates(
users=d.users,
chats=d.chats,
date=state.date,
@@ -2573,10 +2573,6 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier
async def _set_connected_and_authorized(self):
await super()._set_connected_and_authorized()
await self._check_events_pending_resolve()
async def get_entity(self, entity):
"""
Turns the given entity into a valid Telegram user or chat.
@@ -2694,7 +2690,7 @@ class TelegramClient(TelegramBareClient):
try:
# Nobody with this username, maybe it's an exact name/title
return await self.get_entity(
self.session.get_input_entity(string))
await self.session.get_input_entity(string))
except ValueError:
pass
@@ -2729,7 +2725,7 @@ class TelegramClient(TelegramBareClient):
try:
# First try to get the entity from cache, otherwise figure it out
return self.session.get_input_entity(peer)
return await self.session.get_input_entity(peer)
except ValueError:
pass
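All session calls in this file gain an await because the session layer is asynchronous in this branch (delete, cache_file, get_file, get_input_entity). A toy async session, only to illustrate why the call sites must await; it is not the real Session class:

    import asyncio

    class InMemorySession:
        """Toy async session keeping the entity cache in a plain dict."""
        def __init__(self):
            self._entities = {}

        async def save_entity(self, key, value):
            self._entities[key] = value  # a real session might hit a database here

        async def get_input_entity(self, key):
            try:
                return self._entities[key]
            except KeyError:
                raise ValueError('Could not find input entity for %r' % (key,))

    async def demo():
        session = InMemorySession()
        await session.save_entity('@someone', 'InputPeerUser(...)')
        print(await session.get_input_entity('@someone'))

    asyncio.get_event_loop().run_until_complete(demo())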

View File

@@ -192,3 +192,6 @@ class TLObject:
@classmethod
def from_reader(cls, reader):
return TLObject()
def __repr__(self):
return self.__str__()
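Delegating __repr__ to __str__ is what makes TLObjects readable when they end up inside logged lists or dicts (the "Pretty format of TLObject's" item in the merge message). The same pattern on a plain class:

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

        def __str__(self):
            return 'Point(x={}, y={})'.format(self.x, self.y)

        def __repr__(self):
            # Containers and the REPL call __repr__, so reuse the pretty __str__.
            return self.__str__()

    print([Point(1, 2), Point(3, 4)])  # [Point(x=1, y=2), Point(x=3, y=4)]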

View File

@@ -1,65 +0,0 @@
import asyncio
import itertools
import logging
from datetime import datetime
from . import utils
from .tl import types as tl
__log__ = logging.getLogger(__name__)
class UpdateState:
"""
Used to hold the current state of processed updates.
To retrieve an update, :meth:`poll` should be called.
"""
WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers
def __init__(self, loop=None):
self.handler = None
self._loop = loop if loop else asyncio.get_event_loop()
# https://core.telegram.org/api/updates
self._state = tl.updates.State(0, 0, datetime.now(), 0, 0)
def handle_update(self, update):
if self.handler:
asyncio.ensure_future(self.handler(update), loop=self._loop)
def get_update_state(self, entity_id):
"""Gets the updates.State corresponding to the given entity or 0."""
return self._state
def process(self, update):
"""Processes an update object. This method is normally called by
the library itself.
"""
if isinstance(update, tl.updates.State):
__log__.debug('Saved new updates state')
self._state = update
return # Nothing else to be done
if hasattr(update, 'pts'):
self._state.pts = update.pts
# After running the script for over an hour and receiving over
# 1000 updates, the only duplicates received were users going
# online or offline. We can trust the server until new reports.
# This should only be used as read-only.
if isinstance(update, tl.UpdateShort):
update.update._entities = {}
self.handle_update(update.update)
# Expand "Updates" into "Update", and pass these to callbacks.
# Since .users and .chats have already been processed, we
# don't need to care about those either.
elif isinstance(update, (tl.Updates, tl.UpdatesCombined)):
entities = {utils.get_peer_id(x): x for x in
itertools.chain(update.users, update.chats)}
for u in update.updates:
u._entities = entities
self.handle_update(u)
# TODO Handle "tl.UpdatesTooLong"
else:
update._entities = {}
self.handle_update(update)
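The deleted UpdateState.process flattened Updates/UpdatesCombined containers into single Update objects and attached a shared entity map built from .users and .chats before calling the handler; _updates_handler in TelegramBareClient now covers that role. A standalone sketch of the flattening step, with plain stand-in classes and a placeholder peer-id function:

    import itertools

    class UpdatesContainer:
        """Stand-in for tl.Updates: carries .updates, .users and .chats."""
        def __init__(self, updates, users=(), chats=()):
            self.updates, self.users, self.chats = updates, users, chats

    def flatten(container, handle, get_peer_id=id):
        """Expand a container into single updates sharing one entity map."""
        entities = {get_peer_id(x): x for x in
                    itertools.chain(container.users, container.chats)}
        for update in container.updates:
            update._entities = entities  # the same read-only map on every update
            handle(update)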