diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index e9a86e15..999877f6 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -6,32 +6,33 @@ import asyncio import errno import logging import socket -import time from datetime import timedelta from io import BytesIO, BufferedWriter -MAX_TIMEOUT = 15 # in seconds CONN_RESET_ERRNOS = { errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, - errno.EINVAL, errno.ENOTCONN + errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH, + errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED, + errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED, + errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN } +# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH +# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET try: import socks except ImportError: socks = None -MAX_TIMEOUT = 15 # in seconds -CONN_RESET_ERRNOS = { - errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, - errno.EINVAL, errno.ENOTCONN -} - __log__ = logging.getLogger(__name__) class TcpClient: """A simple TCP client to ease the work with sockets and proxies.""" + + class SocketClosed(ConnectionError): + pass + def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None): """ Initializes the TCP client. @@ -42,6 +43,8 @@ class TcpClient: self.proxy = proxy self._socket = None self._loop = loop if loop else asyncio.get_event_loop() + self._closed = asyncio.Event(loop=self._loop) + self._closed.set() if isinstance(timeout, timedelta): self.timeout = timeout.seconds @@ -76,41 +79,28 @@ class TcpClient: else: mode, address = socket.AF_INET, (ip, port) - timeout = 1 - while True: - try: - if not self._socket: - self._recreate_socket(mode) + try: + if not self._socket: + self._recreate_socket(mode) - await self._loop.sock_connect(self._socket, address) - break # Successful connection, stop retrying to connect - except ConnectionError: - self._socket = None - await asyncio.sleep(timeout) - timeout = min(timeout * 2, MAX_TIMEOUT) - except OSError as e: - __log__.info('OSError "%s" raised while connecting', e) - # Stop retrying to connect if proxy connection error occurred - if socks and isinstance(e, socks.ProxyConnectionError): - raise - # There are some errors that we know how to handle, and - # the loop will allow us to retry - if e.errno in (errno.EBADF, errno.ENOTSOCK, errno.EINVAL, - errno.ECONNREFUSED, # Windows-specific follow - getattr(errno, 'WSAEACCES', None)): - # Bad file descriptor, i.e. socket was closed, set it - # to none to recreate it on the next iteration - self._socket = None - await asyncio.sleep(timeout) - timeout *= 2 - if timeout > MAX_TIMEOUT: - raise - else: - raise + await asyncio.wait_for( + self._loop.sock_connect(self._socket, address), + timeout=self.timeout, + loop=self._loop + ) + + self._closed.clear() + except asyncio.TimeoutError as e: + raise TimeoutError() from e + except OSError as e: + if e.errno in CONN_RESET_ERRNOS: + self._raise_connection_reset(e) + else: + raise def _get_connected(self): """Determines whether the client is connected or not.""" - return self._socket is not None and self._socket.fileno() >= 0 + return not self._closed.is_set() connected = property(fget=_get_connected) @@ -118,12 +108,29 @@ class TcpClient: """Closes the connection.""" try: if self._socket is not None: - self._socket.shutdown(socket.SHUT_RDWR) + if self.connected: + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() except OSError: pass # Ignore ENOTCONN, EBADF, and any other error when closing finally: self._socket = None + self._closed.set() + + async def _wait_close(self, coro): + done, running = await asyncio.wait( + [coro, self._closed.wait()], + timeout=self.timeout, + return_when=asyncio.FIRST_COMPLETED, + loop=self._loop + ) + for r in running: + r.cancel() + if not self.connected: + raise self.SocketClosed() + if not done: + raise TimeoutError() + return done.pop().result() async def write(self, data): """ @@ -131,21 +138,12 @@ class TcpClient: :param data: the data to send. """ - if self._socket is None: - self._raise_connection_reset(None) - + if not self.connected: + raise ConnectionResetError('No connection') try: - await asyncio.wait_for( - self.sock_sendall(data), - timeout=self.timeout, - loop=self._loop - ) - except asyncio.TimeoutError as e: - __log__.debug('socket.timeout "%s" while writing data', e) - raise TimeoutError() from e - except ConnectionError as e: - __log__.info('ConnectionError "%s" while writing data', e) - self._raise_connection_reset(e) + await self._wait_close(self.sock_sendall(data)) + except self.SocketClosed: + raise ConnectionResetError('Socket has closed') except OSError as e: __log__.info('OSError "%s" while writing data', e) if e.errno in CONN_RESET_ERRNOS: @@ -160,21 +158,15 @@ class TcpClient: :param size: the size of the block to be read. :return: the read data with len(data) == size. """ - if self._socket is None: - self._raise_connection_reset(None) - with BufferedWriter(BytesIO(), buffer_size=size) as buffer: bytes_left = size + partial = b'' while bytes_left != 0: + if not self.connected: + raise ConnectionResetError('No connection') try: - if self._socket is None: - self._raise_connection_reset() - partial = await asyncio.wait_for( - self.sock_recv(bytes_left), - timeout=self.timeout, - loop=self._loop - ) - except asyncio.TimeoutError as e: + partial = await self._wait_close(self.sock_recv(bytes_left)) + except TimeoutError as e: # These are somewhat common if the server has nothing # to send to us, so use a lower logging priority. if bytes_left < size: @@ -187,10 +179,9 @@ class TcpClient: 'socket.timeout "%s" while reading data', e ) - raise TimeoutError() from e - except ConnectionError as e: - __log__.info('ConnectionError "%s" while reading data', e) - self._raise_connection_reset(e) + raise + except self.SocketClosed: + raise ConnectionResetError('Socket has closed while reading data') except OSError as e: if e.errno != errno.EBADF: # Ignore bad file descriptor while closing @@ -202,7 +193,7 @@ class TcpClient: raise if len(partial) == 0: - self._raise_connection_reset(None) + self._raise_connection_reset('No data on read') buffer.write(partial) bytes_left -= len(partial) @@ -211,10 +202,12 @@ class TcpClient: buffer.flush() return buffer.raw.getvalue() - def _raise_connection_reset(self, original): - """Disconnects the client and raises ConnectionResetError.""" + def _raise_connection_reset(self, error): + description = error if isinstance(error, str) else str(error) + if isinstance(error, str): + error = Exception(error) self.close() # Connection reset -> flag as socket closed - raise ConnectionResetError('The server has closed the connection.') from original + raise ConnectionResetError(description) from error # due to new https://github.com/python/cpython/pull/4386 def sock_recv(self, n): @@ -225,7 +218,7 @@ class TcpClient: def _sock_recv(self, fut, registered_fd, n): if registered_fd is not None: self._loop.remove_reader(registered_fd) - if fut.cancelled(): + if fut.cancelled() or self._socket is None: return try: @@ -249,7 +242,7 @@ class TcpClient: def _sock_sendall(self, fut, registered_fd, data): if registered_fd: self._loop.remove_writer(registered_fd) - if fut.cancelled(): + if fut.cancelled() or self._socket is None: return try: diff --git a/telethon/network/mtproto_sender.py b/telethon/network/mtproto_sender.py index a413fd64..2d7ddf9d 100644 --- a/telethon/network/mtproto_sender.py +++ b/telethon/network/mtproto_sender.py @@ -14,11 +14,11 @@ from ..errors import ( from ..extensions import BinaryReader from ..tl import TLMessage, MessageContainer, GzipPacked from ..tl.all_tlobjects import tlobjects -from ..tl.functions import InvokeAfterMsgRequest from ..tl.functions.auth import LogOutRequest from ..tl.types import ( MsgsAck, Pong, BadServerSalt, BadMsgNotification, FutureSalts, - MsgNewDetailedInfo, NewSessionCreated, MsgDetailedInfo + MsgNewDetailedInfo, MsgDetailedInfo, MsgsStateReq, MsgResendReq, + MsgsAllInfo, MsgsStateInfo, RpcError ) __log__ = logging.getLogger(__name__) @@ -56,7 +56,8 @@ class MtProtoSender: # receiving other request from the main thread (e.g. an update arrives # and we need to process it) we must ensure that only one is calling # receive at a given moment, since the receive step is fragile. - self._recv_lock = asyncio.Lock() + self._read_lock = asyncio.Lock(loop=self._loop) + self._write_lock = asyncio.Lock(loop=self._loop) # Requests (as msg_id: Message) sent waiting to be received self._pending_receive = {} @@ -73,11 +74,12 @@ class MtProtoSender: """ return self.connection.is_connected() - def disconnect(self): + def disconnect(self, clear_pendings=True): """Disconnects from the server.""" __log__.info('Disconnecting MtProtoSender...') self.connection.close() - self._clear_all_pending() + if clear_pendings: + self._clear_all_pending() # region Send and receive @@ -90,6 +92,7 @@ class MtProtoSender: :param ordered: whether the requests should be invoked in the order in which they appear or they can be executed in arbitrary order in the server. + :return: a list of msg_ids which are correspond to sent requests. """ if not utils.is_list_like(requests): requests = (requests,) @@ -111,6 +114,7 @@ class MtProtoSender: messages = [TLMessage(self.session, r) for r in requests] self._pending_receive.update({m.msg_id: m for m in messages}) + msg_ids = [m.msg_id for m in messages] __log__.debug('Sending requests with IDs: %s', ', '.join( '{}: {}'.format(m.request.__class__.__name__, m.msg_id) @@ -128,12 +132,18 @@ class MtProtoSender: m.container_msg_id = message.msg_id await self._send_message(message) + return msg_ids + + def forget_pendings(self, msg_ids): + for msg_id in msg_ids: + if msg_id in self._pending_receive: + del self._pending_receive[msg_id] async def _send_acknowledge(self, msg_id): """Sends a message acknowledge for the given msg_id.""" await self._send_message(TLMessage(self.session, MsgsAck([msg_id]))) - async def receive(self, update_state): + async def receive(self, updates_handler): """ Receives a single message from the connected endpoint. @@ -144,21 +154,13 @@ class MtProtoSender: Any unhandled object (likely updates) will be passed to update_state.process(TLObject). - :param update_state: - the UpdateState that will process all the received + :param updates_handler: + the handler that will process all the received Update and Updates objects. """ - if self._recv_lock.locked(): - with await self._recv_lock: - # Don't busy wait, acquire it but return because there's - # already a receive running and we don't want another one. - # It would lock until Telegram sent another update even if - # the current receive already received the expected response. - return - + await self._read_lock.acquire() try: - with await self._recv_lock: - body = await self.connection.recv() + body = await self.connection.recv() except (BufferError, InvalidChecksumError): # TODO BufferError, we should spot the cause... # "No more bytes left"; something wrong happened, clear @@ -172,11 +174,12 @@ class MtProtoSender: len(self._pending_receive)) self._clear_all_pending() return + finally: + self._read_lock.release() message, remote_msg_id, remote_seq = self._decode_msg(body) with BinaryReader(message) as reader: - await self._process_msg(remote_msg_id, remote_seq, reader, update_state) - await self._send_acknowledge(remote_msg_id) + await self._process_msg(remote_msg_id, remote_seq, reader, updates_handler) # endregion @@ -188,7 +191,11 @@ class MtProtoSender: :param message: the TLMessage to be sent. """ - await self.connection.send(helpers.pack_message(self.session, message)) + await self._write_lock.acquire() + try: + await self.connection.send(helpers.pack_message(self.session, message)) + finally: + self._write_lock.release() def _decode_msg(self, body): """ @@ -206,14 +213,14 @@ class MtProtoSender: with BinaryReader(body) as reader: return helpers.unpack_message(self.session, reader) - async def _process_msg(self, msg_id, sequence, reader, state): + async def _process_msg(self, msg_id, sequence, reader, updates_handler): """ Processes the message read from the network inside reader. :param msg_id: the ID of the message. :param sequence: the sequence of the message. :param reader: the BinaryReader that contains the message. - :param state: the current UpdateState. + :param updates_handler: the handler to process Update and Updates objects. :return: true if the message was handled correctly, false otherwise. """ # TODO Check salt, session_id and sequence_number @@ -224,15 +231,16 @@ class MtProtoSender: # These are a bit of special case, not yet generated by the code gen if code == 0xf35c6d01: # rpc_result, (response of an RPC call) __log__.debug('Processing Remote Procedure Call result') + await self._send_acknowledge(msg_id) return await self._handle_rpc_result(msg_id, sequence, reader) if code == MessageContainer.CONSTRUCTOR_ID: __log__.debug('Processing container result') - return await self._handle_container(msg_id, sequence, reader, state) + return await self._handle_container(msg_id, sequence, reader, updates_handler) if code == GzipPacked.CONSTRUCTOR_ID: __log__.debug('Processing gzipped result') - return await self._handle_gzip_packed(msg_id, sequence, reader, state) + return await self._handle_gzip_packed(msg_id, sequence, reader, updates_handler) if code not in tlobjects: __log__.warning( @@ -250,6 +258,14 @@ class MtProtoSender: if isinstance(obj, BadServerSalt): return await self._handle_bad_server_salt(msg_id, sequence, obj) + if isinstance(obj, (MsgsStateReq, MsgResendReq)): + # just answer we don't know anything + return await self._handle_msgs_state_forgotten(msg_id, sequence, obj) + + if isinstance(obj, MsgsAllInfo): + # not interesting now + return True + if isinstance(obj, BadMsgNotification): return await self._handle_bad_msg_notification(msg_id, sequence, obj) @@ -259,11 +275,8 @@ class MtProtoSender: if isinstance(obj, MsgNewDetailedInfo): return await self._handle_msg_new_detailed_info(msg_id, sequence, obj) - if isinstance(obj, NewSessionCreated): - return await self._handle_new_session_created(msg_id, sequence, obj) - if isinstance(obj, MsgsAck): # may handle the request we wanted - # Ignore every ack request *unless* when logging out, when it's + # Ignore every ack request *unless* when logging out, # when it seems to only make sense. We also need to set a non-None # result since Telegram doesn't send the response for these. for msg_id in obj.msg_ids: @@ -284,8 +297,9 @@ class MtProtoSender: # If the object isn't any of the above, then it should be an Update. self.session.process_entities(obj) - if state: - state.process(obj) + await self._send_acknowledge(msg_id) + if updates_handler: + updates_handler(obj) return True @@ -372,22 +386,24 @@ class MtProtoSender: return True - async def _handle_container(self, msg_id, sequence, reader, state): + async def _handle_container(self, msg_id, sequence, reader, updates_handler): """ Handles a MessageContainer response. :param msg_id: the ID of the message. :param sequence: the sequence of the message. :param reader: the reader containing the MessageContainer. + :param updates_handler: handler to handle Update and Updates objects. :return: true, as it always succeeds. """ + __log__.debug('Handling container') for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader): begin_position = reader.tell_position() # Note that this code is IMPORTANT for skipping RPC results of # lost requests (i.e., ones from the previous connection session) try: - if not await self._process_msg(inner_msg_id, sequence, reader, state): + if not await self._process_msg(inner_msg_id, sequence, reader, updates_handler): reader.set_position(begin_position + inner_len) except: # If any error is raised, something went wrong; skip the packet @@ -406,7 +422,6 @@ class MtProtoSender: :return: true, as it always succeeds. """ self.session.salt = bad_salt.new_server_salt - self.session.save() # "the bad_server_salt response is received with the # correct salt, and the message is to be re-sent with it" @@ -414,6 +429,10 @@ class MtProtoSender: return True + async def _handle_msgs_state_forgotten(self, msg_id, sequence, req): + await self._send_message(TLMessage(self.session, MsgsStateInfo(msg_id, chr(1) * len(req.msg_ids)))) + return True + async def _handle_bad_msg_notification(self, msg_id, sequence, bad_msg): """ Handles a BadMessageError response. @@ -476,19 +495,6 @@ class MtProtoSender: await self._send_acknowledge(msg_new.answer_msg_id) return True - async def _handle_new_session_created(self, msg_id, sequence, new_session): - """ - Handles a NewSessionCreated response. - - :param msg_id: the ID of the message. - :param sequence: the sequence of the message. - :param reader: the reader containing the NewSessionCreated. - :return: true, as it always succeeds. - """ - self.session.salt = new_session.server_salt - # TODO https://goo.gl/LMyN7A - return True - async def _handle_rpc_result(self, msg_id, sequence, reader): """ Handles a RPCResult response. @@ -507,7 +513,7 @@ class MtProtoSender: __log__.debug('Received response for request with ID %d', request_id) request = self._pop_request(request_id) - if inner_code == 0x2144ca19: # RPC Error + if inner_code == RpcError.CONSTRUCTOR_ID: # RPC Error reader.seek(4) if self.session.report_errors and request: error = rpc_message_to_error( @@ -530,6 +536,7 @@ class MtProtoSender: return True # All contents were read okay elif request: + __log__.debug('Reading request response') if inner_code == GzipPacked.CONSTRUCTOR_ID: with BinaryReader(GzipPacked.read(reader)) as compressed_reader: request.on_response(compressed_reader) @@ -566,16 +573,18 @@ class MtProtoSender: ) return False - async def _handle_gzip_packed(self, msg_id, sequence, reader, state): + async def _handle_gzip_packed(self, msg_id, sequence, reader, updates_handler): """ Handles a GzipPacked response. :param msg_id: the ID of the message. :param sequence: the sequence of the message. :param reader: the reader containing the GzipPacked. + :param updates_handler: the handler to process Update and Updates objects. :return: the result of processing the packed message. """ + __log__.debug('Handling gzip packed data') with BinaryReader(GzipPacked.read(reader)) as compressed_reader: - return await self._process_msg(msg_id, sequence, compressed_reader, state) + return await self._process_msg(msg_id, sequence, compressed_reader, updates_handler) # endregion diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index fb0778bd..7e201c63 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -67,6 +67,22 @@ class Session(ABC): """ raise NotImplementedError + @property + @abstractmethod + def user_id(self): + """ + Returns an ``user_id`` which the session related to. + """ + raise NotImplementedError + + @user_id.setter + @abstractmethod + def user_id(self, value): + """ + Sets the ``user_id`` which the session related to. + """ + raise NotImplementedError + @abstractmethod def get_update_state(self, entity_id): """ @@ -94,7 +110,7 @@ class Session(ABC): """ @abstractmethod - def save(self): + async def save(self): """ Called whenever important properties change. It should make persist the relevant session information to disk. @@ -102,7 +118,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def delete(self): + async def delete(self): """ Called upon client.log_out(). Should delete the stored information from disk since it's not valid anymore. @@ -125,7 +141,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def get_input_entity(self, key): + async def get_input_entity(self, key): """ Turns the given key into an ``InputPeer`` (e.g. ``InputPeerUser``). The library uses this method whenever an ``InputPeer`` is needed @@ -135,7 +151,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def cache_file(self, md5_digest, file_size, instance): + async def cache_file(self, md5_digest, file_size, instance): """ Caches the given file information persistently, so that it doesn't need to be re-uploaded in case the file is used again. @@ -146,7 +162,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - def get_file(self, md5_digest, file_size, cls): + async def get_file(self, md5_digest, file_size, cls): """ Returns an instance of ``cls`` if the ``md5_digest`` and ``file_size`` match an existing saved record. The class will either be an diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index acd09a77..49d5cc0f 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -32,6 +32,7 @@ class MemorySession(Session): self._server_address = None self._port = None self._auth_key = None + self._user_id = None self._files = {} self._entities = set() @@ -58,6 +59,14 @@ class MemorySession(Session): def auth_key(self, value): self._auth_key = value + @property + def user_id(self): + return self._user_id + + @user_id.setter + def user_id(self, value): + self._user_id = value + def get_update_state(self, entity_id): return self._update_states.get(entity_id, None) @@ -67,10 +76,10 @@ class MemorySession(Session): def close(self): pass - def save(self): + async def save(self): pass - def delete(self): + async def delete(self): pass def _entity_values_to_row(self, id, hash, username, phone, name): @@ -170,7 +179,7 @@ class MemorySession(Session): except StopIteration: pass - def get_input_entity(self, key): + async def get_input_entity(self, key): try: if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd): # hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel')) @@ -215,14 +224,14 @@ class MemorySession(Session): else: raise ValueError('Could not find input entity with key ', key) - def cache_file(self, md5_digest, file_size, instance): + async def cache_file(self, md5_digest, file_size, instance): if not isinstance(instance, (InputDocument, InputPhoto)): raise TypeError('Cannot cache %s instance' % type(instance)) key = (md5_digest, file_size, _SentFileType.from_type(instance)) value = (instance.id, instance.access_hash) self._files[key] = value - def get_file(self, md5_digest, file_size, cls): + async def get_file(self, md5_digest, file_size, cls): key = (md5_digest, file_size, _SentFileType.from_type(cls)) try: return cls(self._files[key]) diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index 8f18c21e..e2ad013e 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -213,7 +213,7 @@ class SQLiteSession(MemorySession): )) c.close() - def get_update_state(self, entity_id): + async def get_update_state(self, entity_id): c = self._cursor() row = c.execute('select pts, qts, date, seq from update_state ' 'where id = ?', (entity_id,)).fetchone() @@ -223,7 +223,7 @@ class SQLiteSession(MemorySession): date = datetime.datetime.utcfromtimestamp(date) return types.updates.State(pts, qts, date, seq, unread_count=0) - def set_update_state(self, entity_id, state): + async def set_update_state(self, entity_id, state): c = self._cursor() c.execute('insert or replace into update_state values (?,?,?,?,?)', (entity_id, state.pts, state.qts, @@ -231,7 +231,7 @@ class SQLiteSession(MemorySession): c.close() self.save() - def save(self): + async def save(self): """Saves the current session object as session_user_id.session""" self._conn.commit() @@ -248,7 +248,7 @@ class SQLiteSession(MemorySession): self._conn.close() self._conn = None - def delete(self): + async def delete(self): """Deletes the current session file""" if self.filename == ':memory:': return True @@ -319,7 +319,7 @@ class SQLiteSession(MemorySession): # File processing - def get_file(self, md5_digest, file_size, cls): + async def get_file(self, md5_digest, file_size, cls): tuple_ = self._cursor().execute( 'select id, hash from sent_files ' 'where md5_digest = ? and file_size = ? and type = ?', @@ -329,7 +329,7 @@ class SQLiteSession(MemorySession): # Both allowed classes have (id, access_hash) as parameters return cls(tuple_[0], tuple_[1]) - def cache_file(self, md5_digest, file_size, instance): + async def cache_file(self, md5_digest, file_size, instance): if not isinstance(instance, (InputDocument, InputPhoto)): raise TypeError('Cannot cache %s instance' % type(instance)) diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index a6efc709..290d8c78 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -1,8 +1,8 @@ import asyncio import logging import os -from asyncio import Lock -from datetime import timedelta +from asyncio import Lock, Event +from datetime import timedelta, datetime import platform from . import version, utils from .crypto import rsa @@ -13,7 +13,7 @@ from .errors import ( RpcCallFailError ) from .network import authenticator, MtProtoSender, ConnectionTcpFull -from .sessions import Session, SQLiteSession +from .sessions import Session from .tl import TLObject from .tl.all_tlobjects import LAYER from .tl.functions import ( @@ -25,10 +25,17 @@ from .tl.functions.auth import ( from .tl.functions.help import ( GetCdnConfigRequest, GetConfigRequest ) -from .tl.functions.updates import GetStateRequest +from .tl.functions.updates import GetStateRequest, GetDifferenceRequest +from .tl.types import ( + Pong, PeerUser, PeerChat, Message, Updates, UpdateShort, UpdateNewChannelMessage, UpdateEditChannelMessage, + UpdateDeleteChannelMessages, UpdateChannelTooLong, UpdateNewMessage, NewSessionCreated, UpdatesTooLong, + UpdateShortSentMessage, MessageEmpty, UpdateShortMessage, UpdateShortChatMessage, UpdatesCombined +) from .tl.types.auth import ExportedAuthorization -from .update_state import UpdateState +from .tl.types.messages import AffectedMessages, AffectedHistory +from .tl.types.updates import DifferenceEmpty, DifferenceTooLong, DifferenceSlice +MAX_TIMEOUT = 15 # in seconds DEFAULT_DC_ID = 4 DEFAULT_IPV4_IP = '149.154.167.51' DEFAULT_IPV6_IP = '[2001:67c:4e8:f002::a]' @@ -71,8 +78,11 @@ class TelegramBareClient: use_ipv6=False, proxy=None, timeout=timedelta(seconds=5), + ping_delay=timedelta(minutes=1), + update_handler=None, + unauthorized_handler=None, loop=None, - report_errors=True, + report_errors=None, device_model=None, system_version=None, app_version=None, @@ -87,12 +97,8 @@ class TelegramBareClient: self._use_ipv6 = use_ipv6 # Determine what session object we have - if isinstance(session, str) or session is None: - session = SQLiteSession(session) - elif not isinstance(session, Session): - raise TypeError( - 'The given session must be a str or a Session instance.' - ) + if not isinstance(session, Session): + raise TypeError('The given session must be a Session instance.') self._loop = loop if loop else asyncio.get_event_loop() @@ -105,7 +111,8 @@ class TelegramBareClient: DEFAULT_PORT ) - session.report_errors = report_errors + if report_errors is not None: + session.report_errors = report_errors self.session = session self.api_id = int(api_id) self.api_hash = api_hash @@ -128,10 +135,6 @@ class TelegramBareClient: # expensive operation. self._exported_sessions = {} - # This member will process updates if enabled. - # One may change self.updates.enabled at any later point. - self.updates = UpdateState(self._loop) - # Used on connection - the user may modify these and reconnect system = platform.uname() self.device_model = device_model or system.system or 'Unknown' @@ -140,6 +143,12 @@ class TelegramBareClient: self.lang_code = lang_code self.system_lang_code = system_lang_code + self._state = None + self._sync_loading = False + self.update_handler = update_handler + self.unauthorized_handler = unauthorized_handler + self._last_update = datetime.now() + # Despite the state of the real connection, keep track of whether # the user has explicitly called .connect() or .disconnect() here. # This information is required by the read thread, who will be the @@ -147,90 +156,72 @@ class TelegramBareClient: # doesn't explicitly call .disconnect(), thus telling it to stop # retrying. The main thread, knowing there is a background thread # attempting reconnection as soon as it happens, will just sleep. - self._user_connected = False - - # Save whether the user is authorized here (a.k.a. logged in) - self._authorized = None # None = We don't know yet - - # The first request must be in invokeWithLayer(initConnection(X)). - # See https://core.telegram.org/api/invoking#saving-client-info. - self._first_request = True - + self._user_connected = Event(loop=self._loop) + self._authorized = False + self._shutdown = False self._recv_loop = None self._ping_loop = None - self._state_loop = None - self._idling = asyncio.Event() + self._reconnection_loop = None - # Default PingRequest delay - self._ping_delay = timedelta(minutes=1) - # Also have another delay for GetStateRequest. - # - # If the connection is kept alive for long without invoking any - # high level request the server simply stops sending updates. - # TODO maybe we can have ._last_request instead if any req works? - self._state_delay = timedelta(hours=1) + if isinstance(ping_delay, timedelta): + self._ping_delay = ping_delay.seconds + elif isinstance(ping_delay, (int, float)): + self._ping_delay = float(ping_delay) + else: + raise TypeError('Invalid timeout type', type(timeout)) - # endregion - - # region Connecting - - async def connect(self, _sync_updates=True): - """Connects to the Telegram servers, executing authentication if - required. Note that authenticating to the Telegram servers is - not the same as authenticating the desired user itself, which - may require a call (or several) to 'sign_in' for the first time. - - Note that the optional parameters are meant for internal use. - - If '_sync_updates', sync_updates() will be called and a - second thread will be started if necessary. Note that this - will FAIL if the client is not connected to the user's - native data center, raising a "UserMigrateError", and - calling .disconnect() in the process. - """ - __log__.info('Connecting to %s:%d...', - self.session.server_address, self.session.port) + def __del__(self): + self.disconnect() + async def connect(self): try: - await self._sender.connect() - __log__.info('Connection success!') - - # Connection was successful! Try syncing the update state - # UNLESS '_sync_updates' is False (we probably are in - # another data center and this would raise UserMigrateError) - # to also assert whether the user is logged in or not. - self._user_connected = True - if self._authorized is None and _sync_updates: + if not self._sender.is_connected(): + await self._sender.connect() + if not self.session.auth_key: try: - await self.sync_updates() - await self._set_connected_and_authorized() + self.session.auth_key, self.session.time_offset = \ + await authenticator.do_authentication(self._sender.connection) + await self.session.save() + except BrokenAuthKeyError: + self._user_connected.clear() + return False + + if TelegramBareClient._config is None: + TelegramBareClient._config = await self(self._wrap_init_connection(GetConfigRequest())) + + if not self._authorized: + try: + self._state = await self(self._wrap_init_connection(GetStateRequest())) + self._authorized = True except UnauthorizedError: self._authorized = False - elif self._authorized: - await self._set_connected_and_authorized() + self.run_loops() + self._user_connected.set() return True except TypeNotFoundError as e: # This is fine, probably layer migration __log__.warning('Connection failed, got unexpected type with ID ' '%s. Migrating?', hex(e.invalid_constructor_id)) - self.disconnect() - return await self.connect(_sync_updates=_sync_updates) + self.disconnect(False) + return await self.connect() except AuthKeyError as e: # As of late March 2018 there were two AUTH_KEY_DUPLICATED # reports. Retrying with a clean auth_key should fix this. - __log__.warning('Auth key error %s. Clearing it and retrying.', e) - self.disconnect() - self.session.auth_key = None - self.session.save() - return self.connect(_sync_updates=_sync_updates) + if not self._authorized: + __log__.warning('Auth key error %s. Clearing it and retrying.', e) + self.disconnect(False) + self.session.auth_key = None + return self.connect() + else: + raise except (RPCError, ConnectionError) as e: # Probably errors from the previous session, ignore them __log__.error('Connection failed due to %s', e) - self.disconnect() + self.disconnect(False) return False def is_connected(self): @@ -249,24 +240,11 @@ class TelegramBareClient: query=query )) - def disconnect(self): + def disconnect(self, shutdown=True): """Disconnects from the Telegram server""" - __log__.info('Disconnecting...') - self._user_connected = False - self._sender.disconnect() - if self._recv_loop: - self._recv_loop.cancel() - self._recv_loop = None - if self._ping_loop: - self._ping_loop.cancel() - self._ping_loop = None - if self._state_loop: - self._state_loop.cancel() - self._state_loop = None - # TODO Shall we clear the _exported_sessions, or may be reused? - self._first_request = True # On reconnect it will be first again - self.session.set_update_state(0, self.updates.get_update_state(0)) - self.session.close() + self._shutdown = shutdown + self._user_connected.clear() + self._sender.disconnect(clear_pendings=shutdown) async def _reconnect(self, new_dc=None): """If 'new_dc' is not set, only a call to .connect() will be made @@ -277,32 +255,23 @@ class TelegramBareClient: current data center, clears the auth key for the old DC, and connects to the new data center. """ - if new_dc is None: - # Assume we are disconnected due to some error, so connect again - try: - if self.is_connected(): - __log__.info('Reconnection aborted: already connected') - return True + await self._reconnect_lock.acquire() + try: + # Another thread may have connected again, so check that first + if self.is_connected() and new_dc is None: + return True - __log__.info('Attempting reconnection...') - return await self.connect() - except ConnectionResetError as e: - __log__.warning('Reconnection failed due to %s', e) - return False - else: - # Since we're reconnecting possibly due to a UserMigrateError, - # we need to first know the Data Centers we can connect to. Do - # that before disconnecting. - dc = await self._get_dc(new_dc) - __log__.info('Reconnecting to new data center %s', dc) + if new_dc is not None: + dc = await self._get_dc(new_dc) + self.disconnect(False) + self.session.set_dc(dc.id, dc.ip_address, dc.port) + await self.session.save() - self.session.set_dc(dc.id, dc.ip_address, dc.port) - # auth_key's are associated with a server, which has now changed - # so it's not valid anymore. Set to None to force recreating it. - self.session.auth_key = None - self.session.save() - self.disconnect() return await self.connect() + except (ConnectionResetError, TimeoutError): + return False + finally: + self._reconnect_lock.release() def set_proxy(self, proxy): """Change the proxy used by the connections. @@ -348,7 +317,7 @@ class TelegramBareClient: """ # Thanks badoualy/kotlogram on /telegram/api/DefaultTelegramClient.kt # for clearly showing how to export the authorization! ^^ - session = self._exported_sessions.get(dc_id) + session = self._exported_sessions.get(dc_id, None) if session: export_auth = None # Already bound with the auth key else: @@ -377,7 +346,7 @@ class TelegramBareClient: timeout=self._sender.connection.get_timeout(), loop=self._loop ) - await client.connect(_sync_updates=False) + await client.connect() if isinstance(export_auth, ExportedAuthorization): await client(ImportAuthorizationRequest( id=export_auth.id, bytes=export_auth.bytes @@ -385,12 +354,11 @@ class TelegramBareClient: elif export_auth is not None: __log__.warning('Unknown export auth type %s', export_auth) - client._authorized = True # We exported the auth, so we got auth return client async def _get_cdn_client(self, cdn_redirect): """Similar to ._get_exported_client, but for CDNs""" - session = self._exported_sessions.get(cdn_redirect.dc_id) + session = self._exported_sessions.get(cdn_redirect.dc_id, None) if not session: dc = await self._get_dc(cdn_redirect.dc_id, cdn=True) session = self.session.clone() @@ -410,8 +378,7 @@ class TelegramBareClient: # We won't be calling GetConfigRequest because it's only called # when needed by ._get_dc, and also it's static so it's likely # set already. Avoid invoking non-CDN methods by not syncing updates. - await client.connect(_sync_updates=False) - client._authorized = self._authorized + await client.connect() return client # endregion @@ -461,100 +428,70 @@ class TelegramBareClient: which = '{} requests ({})'.format( len(request), [type(x).__name__ for x in request]) + is_ping = any(isinstance(x, PingRequest) for x in request) + msg_ids = [] + __log__.debug('Invoking %s', which) - call_receive = \ - not self._idling.is_set() or self._reconnect_lock.locked() + try: + for retry in range(retries): + result = None + for sub_retry in range(retries): + msg_ids, result = await self._invoke(retry, request, ordered, msg_ids) + if msg_ids: + break + if not self.is_connected(): + break + __log__.error('Subretry %d is failed' % sub_retry) + if result is None: + if not is_ping: + try: + pong = await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True)), retries=1) + if isinstance(pong, Pong): + __log__.error('Connection is live, but no answer on %d retry' % retry) + continue + except RuntimeError: + pass # continue to reconnect + if self.is_connected() and (retry + 1) % 2 == 0: + __log__.error('Force disconnect on %d retry' % retry) + self.disconnect(False) + self._sender.forget_pendings(msg_ids) + msg_ids = [] + if not self.is_connected(): + __log__.error('Pause before new retry on %d retry' % retry) + await asyncio.sleep(retry + 1, loop=self._loop) + else: + return result[0] if single else result + finally: + self._sender.forget_pendings(msg_ids) - for retry in range(retries): - result = await self._invoke(call_receive, retry, request, - ordered=ordered) - if result is not None: - return result[0] if single else result - - log = __log__.info if retry == 0 else __log__.warning - log('Invoking %s failed %d times, connecting again and retrying', - which, retry + 1) - - await asyncio.sleep(1) - if not self._reconnect_lock.locked(): - with await self._reconnect_lock: - await self._reconnect() - - raise RuntimeError('Number of retries reached 0 for {}.'.format( - which - )) + raise RuntimeError('Number of retries is exceeded for {}.'.format(which)) # Let people use client.invoke(SomeRequest()) instead client(...) invoke = __call__ - async def _invoke(self, call_receive, retry, requests, ordered=False): + async def _invoke(self, retry, requests, ordered, msg_ids): try: - # Ensure that we start with no previous errors (i.e. resending) - for x in requests: - x.rpc_error = None + if not msg_ids: + msg_ids = await self._sender.send(requests, ordered) - if not self.session.auth_key: - __log__.info('Need to generate new auth key before invoking') - self._first_request = True - self.session.auth_key, self.session.time_offset = \ - await authenticator.do_authentication(self._sender.connection) + # Ensure that we start with no previous errors (i.e. resending) + for x in requests: + x.rpc_error = None - if self._first_request: - __log__.info('Initializing a new connection while invoking') - if len(requests) == 1: - requests = [self._wrap_init_connection(requests[0])] - else: - # We need a SINGLE request (like GetConfig) to init conn. - # Once that's done, the N original requests will be - # invoked. - TelegramBareClient._config = await self( - self._wrap_init_connection(GetConfigRequest()) - ) - - await self._sender.send(requests, ordered=ordered) - - if not call_receive: - await asyncio.wait( - list(map(lambda x: x.confirm_received.wait(), requests)), - timeout=self._sender.connection.get_timeout(), - loop=self._loop - ) + if self._user_connected.is_set(): + fut = asyncio.gather(*list(map(lambda x: x.confirm_received.wait(), requests)), loop=self._loop) + self._loop.call_later(self._sender.connection.get_timeout(), fut.cancel) + await fut else: while not all(x.confirm_received.is_set() for x in requests): - await self._sender.receive(update_state=self.updates) - - except BrokenAuthKeyError: - __log__.error('Authorization key seems broken and was invalid!') - self.session.auth_key = None - - except TypeNotFoundError as e: - # Only occurs when we call receive. May happen when - # we need to reconnect to another DC on login and - # Telegram somehow sends old objects (like configOld) - self._first_request = True - __log__.warning('Read unknown TLObject code ({}). ' - 'Setting again first_request flag.' - .format(hex(e.invalid_constructor_id))) - - except TimeoutError: - __log__.warning('Invoking timed out') # We will just retry + await self._sender.receive(self._updates_handler) + except (TimeoutError, asyncio.CancelledError): + __log__.error('Timeout on %d retry' % retry) except ConnectionResetError as e: - __log__.warning('Connection was reset while invoking') - if self._user_connected: - # Server disconnected us, __call__ will try reconnecting. - try: - self._sender.disconnect() - except: - pass - - return None - else: - # User never called .connect(), so raise this error. - raise RuntimeError('Tried to invoke without .connect()') from e - - # Clear the flag if we got this far - self._first_request = False + if self._shutdown: + raise + __log__.error('Connection reset on %d retry: %r' % (retry, e)) try: raise next(x.rpc_error for x in requests if x.rpc_error) @@ -562,15 +499,32 @@ class TelegramBareClient: if any(x.result is None for x in requests): # "A container may only be accepted or # rejected by the other party as a whole." - return None + return msg_ids, None - return [x.result for x in requests] + for req in requests: + if isinstance(req.result, TLObject) and req.result.SUBCLASS_OF_ID == Updates.SUBCLASS_OF_ID: + self._updates_handler(req.result, False, False) + if isinstance(req.result, (AffectedMessages, AffectedHistory)): # due to affect to pts + self._updates_handler(UpdateShort(req.result, None), False, False) - except (PhoneMigrateError, NetworkMigrateError, - UserMigrateError) as e: + return msg_ids, [x.result for x in requests] + + except (PhoneMigrateError, NetworkMigrateError, UserMigrateError) as e: + if isinstance(e, (PhoneMigrateError, NetworkMigrateError)): + if self._authorized: + raise + else: + self.session.auth_key = None # Force creating new auth_key + + __log__.error( + 'DC error when invoking request, ' + 'attempting to reconnect at DC {}'.format(e.new_dc) + ) await self._reconnect(new_dc=e.new_dc) - return await self._invoke(call_receive, retry, requests) + self._sender.forget_pendings(msg_ids) + msg_ids = [] + return msg_ids, None except (ServerError, RpcCallFailError) as e: # Telegram is having some issues, just retry @@ -582,7 +536,16 @@ class TelegramBareClient: raise await asyncio.sleep(e.seconds, loop=self._loop) - return None + return msg_ids, None + + except UnauthorizedError: + if self._authorized: + __log__.error('Authorization has lost') + self._authorized = False + self.disconnect() + if self.unauthorized_handler: + await self.unauthorized_handler(self) + raise # Some really basic functionality @@ -602,74 +565,216 @@ class TelegramBareClient: # region Updates handling - async def sync_updates(self): - """Synchronizes self.updates to their initial state. Will be - called automatically on connection if self.updates.enabled = True, - otherwise it should be called manually after enabling updates. - """ - self.updates.process(await self(GetStateRequest())) + async def _handle_update(self, update, seq_start, seq, date, do_get_diff, do_handlers, users=(), chats=()): + if isinstance(update, (UpdateNewChannelMessage, UpdateEditChannelMessage, + UpdateDeleteChannelMessages, UpdateChannelTooLong)): + # TODO: channel updates have their own pts sequences, so requires individual pts'es + return # ignore channel updates to keep pts in the main _state in the correct state + if hasattr(update, 'pts'): + new_pts = self._state.pts + getattr(update, 'pts_count', 0) + if new_pts < update.pts: + __log__.debug('Have got a hole between pts => waiting 0.5 sec') + await asyncio.sleep(0.5, loop=self._loop) + if new_pts < update.pts: + if do_get_diff and not self._sync_loading: + __log__.debug('The hole between pts has not disappeared => going to get differences') + self._sync_loading = True + asyncio.ensure_future(self._get_difference(), loop=self._loop) + return + if update.pts > self._state.pts: + self._state.pts = update.pts + elif getattr(update, 'pts_count', 0) > 0: + __log__.debug('Have got the duplicate update (basing on pts) => ignoring') + return + elif hasattr(update, 'qts'): + if self._state.qts + 1 < update.qts: + __log__.debug('Have got a hole between qts => waiting 0.5 sec') + await asyncio.sleep(0.5, loop=self._loop) + if self._state.qts + 1 < update.qts: + if do_get_diff and not self._sync_loading: + __log__.debug('The hole between qts has not disappeared => going to get differences') + self._sync_loading = True + asyncio.ensure_future(self._get_difference(), loop=self._loop) + return + if update.qts > self._state.qts: + self._state.qts = update.qts + else: + __log__.debug('Have got the duplicate update (basing on qts) => ignoring') + return + elif seq > 0: + if seq_start > self._state.seq + 1: + __log__.debug('Have got a hole between seq => waiting 0.5 sec') + await asyncio.sleep(0.5, loop=self._loop) + if seq_start > self._state.seq + 1: + if do_get_diff and not self._sync_loading: + __log__.debug('The hole between seq has not disappeared => going to get differences') + self._sync_loading = True + asyncio.ensure_future(self._get_difference(), loop=self._loop) + return + self._state.seq = seq + self._state.date = max(self._state.date, date) - # endregion + if do_handlers and self.update_handler: + asyncio.ensure_future(self.update_handler(self, update, users, chats), loop=self._loop) - # Constant read + async def _get_difference(self): + self._sync_loading = True + try: + difference = await self(GetDifferenceRequest(self._state.pts, self._state.date, self._state.qts)) + if isinstance(difference, DifferenceEmpty): + __log__.debug('Have got DifferenceEmpty => just update seq and date') + self._state.seq = difference.seq + self._state.date = difference.date + return + if isinstance(difference, DifferenceTooLong): + __log__.debug('Have got DifferenceTooLong => update pts and try again') + self._state.pts = difference.pts + asyncio.ensure_future(self._get_difference(), loop=self._loop) + return + __log__.debug('Preparing updates from differences') + self._state = difference.intermediate_state \ + if isinstance(difference, DifferenceSlice) else difference.state + messages = [UpdateNewMessage(msg, self._state.pts, 0) for msg in difference.new_messages] + self._updates_handler( + Updates(messages + difference.other_updates, + difference.users, difference.chats, self._state.date, self._state.seq), + False + ) + if isinstance(difference, DifferenceSlice): + asyncio.ensure_future(self._get_difference(), loop=self._loop) + except ConnectionResetError: # it happens on unauth due to _get_difference is often on the background + pass + except Exception as e: + __log__.exception('Exception on _get_difference: %r', e) + finally: + self._sync_loading = False - # This is async so that the overrided version in TelegramClient can be - # async without problems. - async def _set_connected_and_authorized(self): - self._authorized = True + # TODO: Some of logic was moved from MtProtoSender and probably must be moved back. + def _updates_handler(self, updates, do_get_diff=True, do_handlers=True): + if do_get_diff: + self._last_update = datetime.now() + if isinstance(updates, NewSessionCreated): + self.session.salt = updates.server_salt + if self._state is None: + return False # not ready yet + if self._sync_loading and do_get_diff: + return False # ignore all if in sync except from difference (do_get_diff = False) + if isinstance(updates, (NewSessionCreated, UpdatesTooLong)): + if do_get_diff: # to prevent possible loops + __log__.debug('Have got %s => going to get differences', type(updates)) + self._sync_loading = True + asyncio.ensure_future(self._get_difference(), loop=self._loop) + return False + + seq = getattr(updates, 'seq', 0) + seq_start = getattr(updates, 'seq_start', seq) + date = getattr(updates, 'date', self._state.date) + + if isinstance(updates, UpdateShort): + asyncio.ensure_future( + self._handle_update(updates.update, seq_start, seq, date, do_get_diff, do_handlers), + loop=self._loop + ) + return True + + if isinstance(updates, UpdateShortSentMessage): + asyncio.ensure_future(self._handle_update( + UpdateNewMessage(MessageEmpty(updates.id), updates.pts, updates.pts_count), + seq_start, seq, date, do_get_diff, do_handlers + ), loop=self._loop) + return True + + if isinstance(updates, (UpdateShortMessage, UpdateShortChatMessage)): + from_id = getattr(updates, 'from_id', self.session.user_id) + to_id = updates.user_id if isinstance(updates, UpdateShortMessage) else updates.chat_id + if not updates.out: + from_id, to_id = to_id, from_id + to_id = PeerUser(to_id) if isinstance(updates, UpdateShortMessage) else PeerChat(to_id) + message = Message( + id=updates.id, to_id=to_id, date=updates.date, message=updates.message, out=updates.out, + mentioned=updates.mentioned, media_unread=updates.media_unread, silent=updates.silent, + from_id=from_id, fwd_from=updates.fwd_from, via_bot_id=updates.via_bot_id, + reply_to_msg_id=updates.reply_to_msg_id, entities=updates.entities + ) + asyncio.ensure_future(self._handle_update( + UpdateNewMessage(message, updates.pts, updates.pts_count), + seq_start, seq, date, do_get_diff, do_handlers + ), loop=self._loop) + return True + + if isinstance(updates, (Updates, UpdatesCombined)): + for upd in updates.updates: + asyncio.ensure_future( + self._handle_update(upd, seq_start, seq, date, do_get_diff, do_handlers, + updates.users, updates.chats), + loop=self._loop + ) + return True + + if do_get_diff: # to prevent possible loops + __log__.debug('Have got unsupported type of updates: %s => going to get differences', type(updates)) + self._sync_loading = True + asyncio.ensure_future(self._get_difference(), loop=self._loop) + + return False + + def run_loops(self): if self._recv_loop is None: self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop) if self._ping_loop is None: self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop) - if self._state_loop is None: - self._state_loop = asyncio.ensure_future(self._state_loop_impl(), loop=self._loop) async def _ping_loop_impl(self): - while self._user_connected: - await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True))) - await asyncio.sleep(self._ping_delay.seconds, loop=self._loop) + while True: + if self._shutdown: + break + try: + await self._user_connected.wait() + await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True))) + await asyncio.sleep(self._ping_delay, loop=self._loop) + except RuntimeError: + pass # Can be not happy due to connection problems + except asyncio.CancelledError: + break + except: + self._ping_loop = None + raise self._ping_loop = None - async def _state_loop_impl(self): - while self._user_connected: - await asyncio.sleep(self._state_delay.seconds, loop=self._loop) - await self._sender.send(GetStateRequest()) - async def _recv_loop_impl(self): - __log__.info('Starting to wait for items from the network') - self._idling.set() - need_reconnect = False - while self._user_connected: + timeout = 1 + while True: + if self._shutdown: + break try: - if need_reconnect: - __log__.info('Attempting reconnection from read loop') - need_reconnect = False - with await self._reconnect_lock: - while self._user_connected and not await self._reconnect(): - # Retry forever, this is instant messaging - await asyncio.sleep(0.1, loop=self._loop) - - # Telegram seems to kick us every 1024 items received - # from the network not considering things like bad salt. - # We must execute some *high level* request (that's not - # a ping) if we want to receive updates again. - # TODO Test if getDifference works too (better alternative) - await self._sender.send(GetStateRequest()) - - __log__.debug('Receiving items from the network...') - await self._sender.receive(update_state=self.updates) + if self._user_connected.is_set(): + if self._authorized and datetime.now() - self._last_update > timedelta(minutes=15): + __log__.debug('No updates for 15 minutes => going to get differences') + self._last_update = datetime.now() + self._sync_loading = True + asyncio.ensure_future(self._get_difference(), loop=self._loop) + await self._sender.receive(self._updates_handler) + else: + if await self._reconnect(): + __log__.info('Connection has established') + timeout = 1 + else: + await asyncio.sleep(timeout, loop=self._loop) + timeout = min(timeout * 2, MAX_TIMEOUT) except TimeoutError: # No problem. - __log__.debug('Receiving items from the network timed out') - except ConnectionError: - need_reconnect = True - __log__.error('Connection was reset while receiving items') - await asyncio.sleep(1, loop=self._loop) - except: - self._idling.clear() - raise - - self._idling.clear() - __log__.info('Connection closed by the user, not reading anymore') + pass + except ConnectionResetError as error: + __log__.info('Connection reset error in recv loop: %r' % error) + self._user_connected.clear() + except asyncio.CancelledError: + self.disconnect() + break + except Exception as error: + # Unknown exception, pass it to the main thread + __log__.exception('[ERROR: %r] on the read loop, please report', error) + self._recv_loop = None + if self._shutdown and self._ping_loop: + self._ping_loop.cancel() # endregion diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index 5f46551f..bda21e06 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -510,7 +510,7 @@ class TelegramClient(TelegramBareClient): return False self.disconnect() - self.session.delete() + await self.session.delete() self._authorized = False return True @@ -1805,7 +1805,7 @@ class TelegramClient(TelegramBareClient): to_cache = utils.get_input_photo(msg.media.photo) else: to_cache = utils.get_input_document(msg.media.document) - self.session.cache_file(md5, size, to_cache) + await self.session.cache_file(md5, size, to_cache) return msg @@ -1849,7 +1849,7 @@ class TelegramClient(TelegramBareClient): input_photo = utils.get_input_photo((await self(UploadMediaRequest( entity, media=InputMediaUploadedPhoto(fh) ))).photo) - self.session.cache_file(fh.md5, fh.size, input_photo) + await self.session.cache_file(fh.md5, fh.size, input_photo) fh = input_photo if captions: @@ -1957,7 +1957,7 @@ class TelegramClient(TelegramBareClient): file = stream.read() hash_md5.update(file) if use_cache: - cached = self.session.get_file( + cached = await self.session.get_file( hash_md5.digest(), file_size, cls=use_cache ) if cached: @@ -2122,7 +2122,7 @@ class TelegramClient(TelegramBareClient): media, file, date, progress_callback ) elif isinstance(media, MessageMediaContact): - return await self._download_contact( + return self._download_contact( media, file ) @@ -2472,7 +2472,7 @@ class TelegramClient(TelegramBareClient): be passed instead. """ - self.updates.handler = self._on_handler + self.update_handler = self._on_handler if isinstance(event, type): event = event() elif not event: @@ -2555,7 +2555,7 @@ class TelegramClient(TelegramBareClient): # infinite loop here (so check against old pts to stop) break - self.updates.process(Updates( + self._updates_handler(Updates( users=d.users, chats=d.chats, date=state.date, @@ -2573,10 +2573,6 @@ class TelegramClient(TelegramBareClient): # region Small utilities to make users' life easier - async def _set_connected_and_authorized(self): - await super()._set_connected_and_authorized() - await self._check_events_pending_resolve() - async def get_entity(self, entity): """ Turns the given entity into a valid Telegram user or chat. @@ -2694,7 +2690,7 @@ class TelegramClient(TelegramBareClient): try: # Nobody with this username, maybe it's an exact name/title return await self.get_entity( - self.session.get_input_entity(string)) + await self.session.get_input_entity(string)) except ValueError: pass @@ -2729,7 +2725,7 @@ class TelegramClient(TelegramBareClient): try: # First try to get the entity from cache, otherwise figure it out - return self.session.get_input_entity(peer) + return await self.session.get_input_entity(peer) except ValueError: pass diff --git a/telethon/tl/tlobject.py b/telethon/tl/tlobject.py index ede7a8a1..5637ba19 100644 --- a/telethon/tl/tlobject.py +++ b/telethon/tl/tlobject.py @@ -192,3 +192,6 @@ class TLObject: @classmethod def from_reader(cls, reader): return TLObject() + + def __repr__(self): + return self.__str__() diff --git a/telethon/update_state.py b/telethon/update_state.py deleted file mode 100644 index 4f0e38ed..00000000 --- a/telethon/update_state.py +++ /dev/null @@ -1,65 +0,0 @@ -import asyncio -import itertools -import logging -from datetime import datetime - -from . import utils -from .tl import types as tl - -__log__ = logging.getLogger(__name__) - - -class UpdateState: - """ - Used to hold the current state of processed updates. - To retrieve an update, :meth:`poll` should be called. - """ - WORKER_POLL_TIMEOUT = 5.0 # Avoid waiting forever on the workers - - def __init__(self, loop=None): - self.handler = None - self._loop = loop if loop else asyncio.get_event_loop() - - # https://core.telegram.org/api/updates - self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) - - def handle_update(self, update): - if self.handler: - asyncio.ensure_future(self.handler(update), loop=self._loop) - - def get_update_state(self, entity_id): - """Gets the updates.State corresponding to the given entity or 0.""" - return self._state - - def process(self, update): - """Processes an update object. This method is normally called by - the library itself. - """ - if isinstance(update, tl.updates.State): - __log__.debug('Saved new updates state') - self._state = update - return # Nothing else to be done - - if hasattr(update, 'pts'): - self._state.pts = update.pts - - # After running the script for over an hour and receiving over - # 1000 updates, the only duplicates received were users going - # online or offline. We can trust the server until new reports. - # This should only be used as read-only. - if isinstance(update, tl.UpdateShort): - update.update._entities = {} - self.handle_update(update.update) - # Expand "Updates" into "Update", and pass these to callbacks. - # Since .users and .chats have already been processed, we - # don't need to care about those either. - elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): - entities = {utils.get_peer_id(x): x for x in - itertools.chain(update.users, update.chats)} - for u in update.updates: - u._entities = entities - self.handle_update(u) - # TODO Handle "tl.UpdatesTooLong" - else: - update._entities = {} - self.handle_update(update)