mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-31 16:07:44 +03:00 
			
		
		
		
	Update handlers works; it also seems stable
This commit is contained in:
		
							parent
							
								
									917665852d
								
							
						
					
					
						commit
						780e0ceddf
					
				|  | @ -5,13 +5,12 @@ import socket | |||
| from datetime import timedelta | ||||
| from io import BytesIO, BufferedWriter | ||||
| 
 | ||||
| loop = asyncio.get_event_loop() | ||||
| 
 | ||||
| 
 | ||||
| class TcpClient: | ||||
|     def __init__(self, proxy=None, timeout=timedelta(seconds=5)): | ||||
|     def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None): | ||||
|         self.proxy = proxy | ||||
|         self._socket = None | ||||
|         self._loop = loop if loop else asyncio.get_event_loop() | ||||
| 
 | ||||
|         if isinstance(timeout, timedelta): | ||||
|             self.timeout = timeout.seconds | ||||
|  | @ -31,7 +30,7 @@ class TcpClient: | |||
|             else:  # tuple, list, etc. | ||||
|                 self._socket.set_proxy(*self.proxy) | ||||
| 
 | ||||
|         self._socket.settimeout(self.timeout) | ||||
|         self._socket.setblocking(False) | ||||
| 
 | ||||
|     async def connect(self, ip, port): | ||||
|         """Connects to the specified IP and port number. | ||||
|  | @ -42,20 +41,27 @@ class TcpClient: | |||
|         else: | ||||
|             mode, address = socket.AF_INET, (ip, port) | ||||
| 
 | ||||
|         timeout = 1 | ||||
|         while True: | ||||
|             try: | ||||
|                 while not self._socket: | ||||
|                 if not self._socket: | ||||
|                     self._recreate_socket(mode) | ||||
| 
 | ||||
|                 await loop.sock_connect(self._socket, address) | ||||
|                 await self._loop.sock_connect(self._socket, address) | ||||
|                 break  # Successful connection, stop retrying to connect | ||||
|             except ConnectionError: | ||||
|                 self._socket = None | ||||
|                 await asyncio.sleep(min(timeout, 15)) | ||||
|                 timeout *= 2 | ||||
|             except OSError as e: | ||||
|                 # There are some errors that we know how to handle, and | ||||
|                 # the loop will allow us to retry | ||||
|                 if e.errno == errno.EBADF: | ||||
|                 if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.EINVAL]: | ||||
|                     # Bad file descriptor, i.e. socket was closed, set it | ||||
|                     # to none to recreate it on the next iteration | ||||
|                     self._socket = None | ||||
|                     await asyncio.sleep(min(timeout, 15)) | ||||
|                     timeout *= 2 | ||||
|                 else: | ||||
|                     raise | ||||
| 
 | ||||
|  | @ -81,13 +87,14 @@ class TcpClient: | |||
|             raise ConnectionResetError() | ||||
| 
 | ||||
|         try: | ||||
|             await loop.sock_sendall(self._socket, data) | ||||
|         except socket.timeout as e: | ||||
|             await asyncio.wait_for(self._loop.sock_sendall(self._socket, data), | ||||
|                                    timeout=self.timeout, loop=self._loop) | ||||
|         except asyncio.TimeoutError as e: | ||||
|             raise TimeoutError() from e | ||||
|         except BrokenPipeError: | ||||
|             self._raise_connection_reset() | ||||
|         except OSError as e: | ||||
|             if e.errno == errno.EBADF: | ||||
|             if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]: | ||||
|                 self._raise_connection_reset() | ||||
|             else: | ||||
|                 raise | ||||
|  | @ -104,11 +111,12 @@ class TcpClient: | |||
|             bytes_left = size | ||||
|             while bytes_left != 0: | ||||
|                 try: | ||||
|                     partial = await loop.sock_recv(self._socket, bytes_left) | ||||
|                 except socket.timeout as e: | ||||
|                     partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left), | ||||
|                                                      timeout=self.timeout, loop=self._loop) | ||||
|                 except asyncio.TimeoutError as e: | ||||
|                     raise TimeoutError() from e | ||||
|                 except OSError as e: | ||||
|                     if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK: | ||||
|                     if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]: | ||||
|                         self._raise_connection_reset() | ||||
|                     else: | ||||
|                         raise | ||||
|  |  | |||
|  | @ -43,13 +43,13 @@ class Connection: | |||
|     """ | ||||
| 
 | ||||
|     def __init__(self, mode=ConnectionMode.TCP_FULL, | ||||
|                  proxy=None, timeout=timedelta(seconds=5)): | ||||
|                  proxy=None, timeout=timedelta(seconds=5), loop=None): | ||||
|         self._mode = mode | ||||
|         self._send_counter = 0 | ||||
|         self._aes_encrypt, self._aes_decrypt = None, None | ||||
| 
 | ||||
|         # TODO Rename "TcpClient" as some sort of generic socket? | ||||
|         self.conn = TcpClient(proxy=proxy, timeout=timeout) | ||||
|         self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop) | ||||
| 
 | ||||
|         # Sending messages | ||||
|         if mode == ConnectionMode.TCP_FULL: | ||||
|  | @ -206,7 +206,7 @@ class Connection: | |||
|         return await self.conn.read(length) | ||||
| 
 | ||||
|     async def _read_obfuscated(self, length): | ||||
|         return await self._aes_decrypt.encrypt(self.conn.read(length)) | ||||
|         return self._aes_decrypt.encrypt(await self.conn.read(length)) | ||||
| 
 | ||||
|     # endregion | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,6 +1,8 @@ | |||
| import gzip | ||||
| import logging | ||||
| import struct | ||||
| import asyncio | ||||
| from asyncio import Event | ||||
| 
 | ||||
| from .. import helpers as utils | ||||
| from ..crypto import AES | ||||
|  | @ -30,17 +32,15 @@ class MtProtoSender: | |||
|                   in parallel, so thread-safety (hence locking) isn't needed. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, session, connection): | ||||
|     def __init__(self, session, connection, loop=None): | ||||
|         """Creates a new MtProtoSender configured to send messages through | ||||
|            'connection' and using the parameters from 'session'. | ||||
|         """ | ||||
|         self.session = session | ||||
|         self.connection = connection | ||||
|         self._loop = loop if loop else asyncio.get_event_loop() | ||||
|         self._logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|         # Message IDs that need confirmation | ||||
|         self._need_confirmation = [] | ||||
| 
 | ||||
|         # Requests (as msg_id: Message) sent waiting to be received | ||||
|         self._pending_receive = {} | ||||
| 
 | ||||
|  | @ -54,12 +54,11 @@ class MtProtoSender: | |||
|     def disconnect(self): | ||||
|         """Disconnects from the server""" | ||||
|         self.connection.close() | ||||
|         self._need_confirmation.clear() | ||||
|         self._clear_all_pending() | ||||
| 
 | ||||
|     def clone(self): | ||||
|         """Creates a copy of this MtProtoSender as a new connection""" | ||||
|         return MtProtoSender(self.session, self.connection.clone()) | ||||
|         return MtProtoSender(self.session, self.connection.clone(), self._loop) | ||||
| 
 | ||||
|     # region Send and receive | ||||
| 
 | ||||
|  | @ -67,21 +66,23 @@ class MtProtoSender: | |||
|         """Sends the specified MTProtoRequest, previously sending any message | ||||
|            which needed confirmation.""" | ||||
| 
 | ||||
|         # Prepare the event of every request | ||||
|         for r in requests: | ||||
|             if r.confirm_received is None: | ||||
|                 r.confirm_received = Event(loop=self._loop) | ||||
|             else: | ||||
|                 r.confirm_received.clear() | ||||
| 
 | ||||
|         # Finally send our packed request(s) | ||||
|         messages = [TLMessage(self.session, r) for r in requests] | ||||
|         self._pending_receive.update({m.msg_id: m for m in messages}) | ||||
| 
 | ||||
|         # Pack everything in the same container if we need to send AckRequests | ||||
|         if self._need_confirmation: | ||||
|             messages.append( | ||||
|                 TLMessage(self.session, MsgsAck(self._need_confirmation)) | ||||
|             ) | ||||
|             self._need_confirmation.clear() | ||||
| 
 | ||||
|         if len(messages) == 1: | ||||
|             message = messages[0] | ||||
|         else: | ||||
|             message = TLMessage(self.session, MessageContainer(messages)) | ||||
|             for m in messages: | ||||
|                 m.container_msg_id = message.msg_id | ||||
| 
 | ||||
|         await self._send_message(message) | ||||
| 
 | ||||
|  | @ -115,6 +116,7 @@ class MtProtoSender: | |||
|         message, remote_msg_id, remote_seq = self._decode_msg(body) | ||||
|         with BinaryReader(message) as reader: | ||||
|             await self._process_msg(remote_msg_id, remote_seq, reader, update_state) | ||||
|             await self._send_acknowledge(remote_msg_id) | ||||
| 
 | ||||
|     # endregion | ||||
| 
 | ||||
|  | @ -174,7 +176,6 @@ class MtProtoSender: | |||
|         """ | ||||
| 
 | ||||
|         # TODO Check salt, session_id and sequence_number | ||||
|         self._need_confirmation.append(msg_id) | ||||
| 
 | ||||
|         code = reader.read_int(signed=False) | ||||
|         reader.seek(-4) | ||||
|  | @ -210,14 +211,14 @@ class MtProtoSender: | |||
|         if code == MsgsAck.CONSTRUCTOR_ID:  # may handle the request we wanted | ||||
|             ack = reader.tgread_object() | ||||
|             assert isinstance(ack, MsgsAck) | ||||
|             # Ignore every ack request *unless* when logging out, when it's | ||||
|             # Ignore every ack request *unless* when logging out, | ||||
|             # when it seems to only make sense. We also need to set a non-None | ||||
|             # result since Telegram doesn't send the response for these. | ||||
|             for msg_id in ack.msg_ids: | ||||
|                 r = self._pop_request_of_type(msg_id, LogOutRequest) | ||||
|                 if r: | ||||
|                     r.result = True  # Telegram won't send this value | ||||
|                     r.confirm_received() | ||||
|                     r.confirm_received.set() | ||||
|                     self._logger.debug('Message ack confirmed', r) | ||||
| 
 | ||||
|             return True | ||||
|  | @ -259,11 +260,29 @@ class MtProtoSender: | |||
|         if message and isinstance(message.request, t): | ||||
|             return self._pending_receive.pop(msg_id).request | ||||
| 
 | ||||
|     def _pop_requests_of_container(self, container_msg_id): | ||||
|         msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id] | ||||
|         requests = [msg.request for msg in msgs] | ||||
|         for msg in msgs: | ||||
|             self._pending_receive.pop(msg.msg_id, None) | ||||
|         return requests | ||||
| 
 | ||||
|     def _clear_all_pending(self): | ||||
|         for r in self._pending_receive.values(): | ||||
|             r.confirm_received.set() | ||||
|             r.request.confirm_received.set() | ||||
|         self._pending_receive.clear() | ||||
| 
 | ||||
|     async def _resend_request(self, msg_id): | ||||
|         request = self._pop_request(msg_id) | ||||
|         if request: | ||||
|             self._logger.debug('requests is about to resend') | ||||
|             await self.send(request) | ||||
|             return | ||||
|         requests = self._pop_requests_of_container(msg_id) | ||||
|         if requests: | ||||
|             self._logger.debug('container of requests is about to resend') | ||||
|             await self.send(*requests) | ||||
| 
 | ||||
|     async def _handle_pong(self, msg_id, sequence, reader): | ||||
|         self._logger.debug('Handling pong') | ||||
|         pong = reader.tgread_object() | ||||
|  | @ -303,10 +322,9 @@ class MtProtoSender: | |||
|         self.session.salt = struct.unpack( | ||||
|             '<Q', struct.pack('<q', bad_salt.new_server_salt) | ||||
|         )[0] | ||||
|         self.session.save() | ||||
| 
 | ||||
|         request = self._pop_request(bad_salt.bad_msg_id) | ||||
|         if request: | ||||
|             await self.send(request) | ||||
|         await self._resend_request(bad_salt.bad_msg_id) | ||||
| 
 | ||||
|         return True | ||||
| 
 | ||||
|  | @ -322,15 +340,18 @@ class MtProtoSender: | |||
|             self.session.update_time_offset(correct_msg_id=msg_id) | ||||
|             self._logger.debug('Read Bad Message error: ' + str(error)) | ||||
|             self._logger.debug('Attempting to use the correct time offset.') | ||||
|             await self._resend_request(bad_msg.bad_msg_id) | ||||
|             return True | ||||
|         elif bad_msg.error_code == 32: | ||||
|             # msg_seqno too low, so just pump it up by some "large" amount | ||||
|             # TODO A better fix would be to start with a new fresh session ID | ||||
|             self.session._sequence += 64 | ||||
|             await self._resend_request(bad_msg.bad_msg_id) | ||||
|             return True | ||||
|         elif bad_msg.error_code == 33: | ||||
|             # msg_seqno too high never seems to happen but just in case | ||||
|             self.session._sequence -= 16 | ||||
|             await self._resend_request(bad_msg.bad_msg_id) | ||||
|             return True | ||||
|         else: | ||||
|             raise error | ||||
|  | @ -341,7 +362,6 @@ class MtProtoSender: | |||
| 
 | ||||
|         # TODO For now, simply ack msg_new.answer_msg_id | ||||
|         # Relevant tdesktop source code: https://goo.gl/VvpCC6 | ||||
|         await self._send_acknowledge(msg_new.answer_msg_id) | ||||
|         return True | ||||
| 
 | ||||
|     async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader): | ||||
|  | @ -350,7 +370,6 @@ class MtProtoSender: | |||
| 
 | ||||
|         # TODO For now, simply ack msg_new.answer_msg_id | ||||
|         # Relevant tdesktop source code: https://goo.gl/G7DPsR | ||||
|         await self._send_acknowledge(msg_new.answer_msg_id) | ||||
|         return True | ||||
| 
 | ||||
|     async def _handle_new_session_created(self, msg_id, sequence, reader): | ||||
|  | @ -378,9 +397,6 @@ class MtProtoSender: | |||
|                     reader.read_int(), reader.tgread_string() | ||||
|                 ) | ||||
| 
 | ||||
|             # Acknowledge that we received the error | ||||
|             await self._send_acknowledge(request_id) | ||||
| 
 | ||||
|             if request: | ||||
|                 request.rpc_error = error | ||||
|                 request.confirm_received.set() | ||||
|  |  | |||
|  | @ -1,10 +1,10 @@ | |||
| import logging | ||||
| import os | ||||
| import warnings | ||||
| import asyncio | ||||
| from datetime import timedelta, datetime | ||||
| from hashlib import md5 | ||||
| from io import BytesIO | ||||
| from time import sleep | ||||
| from asyncio import Lock | ||||
| 
 | ||||
| from . import helpers as utils | ||||
| from .crypto import rsa, CdnDecrypter | ||||
|  | @ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode | |||
| from .tl import TLObject, Session | ||||
| from .tl.all_tlobjects import LAYER | ||||
| from .tl.functions import ( | ||||
|     InitConnectionRequest, InvokeWithLayerRequest | ||||
|     InitConnectionRequest, InvokeWithLayerRequest, PingRequest | ||||
| ) | ||||
| from .tl.functions.auth import ( | ||||
|     ImportAuthorizationRequest, ExportAuthorizationRequest | ||||
|  | @ -67,6 +67,7 @@ class TelegramBareClient: | |||
|                  connection_mode=ConnectionMode.TCP_FULL, | ||||
|                  proxy=None, | ||||
|                  timeout=timedelta(seconds=5), | ||||
|                  loop=None, | ||||
|                  **kwargs): | ||||
|         """Refer to TelegramClient.__init__ for docs on this method""" | ||||
|         if not api_id or not api_hash: | ||||
|  | @ -82,6 +83,8 @@ class TelegramBareClient: | |||
|                 'The given session must be a str or a Session instance.' | ||||
|             ) | ||||
| 
 | ||||
|         self._loop = loop if loop else asyncio.get_event_loop() | ||||
| 
 | ||||
|         self.session = session | ||||
|         self.api_id = int(api_id) | ||||
|         self.api_hash = api_hash | ||||
|  | @ -92,12 +95,18 @@ class TelegramBareClient: | |||
|         # that calls .connect(). Every other thread will spawn a new | ||||
|         # temporary connection. The connection on this one is always | ||||
|         # kept open so Telegram can send us updates. | ||||
|         self._sender = MtProtoSender(self.session, Connection( | ||||
|             mode=connection_mode, proxy=proxy, timeout=timeout | ||||
|         )) | ||||
|         self._sender = MtProtoSender( | ||||
|             self.session, | ||||
|             Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop), | ||||
|             self._loop | ||||
|         ) | ||||
| 
 | ||||
|         self._logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|         # Two coroutines may be calling reconnect() when the connection is lost, | ||||
|         # we only want one to actually perform the reconnection. | ||||
|         self._reconnect_lock = Lock(loop=self._loop) | ||||
| 
 | ||||
|         # Cache "exported" sessions as 'dc_id: Session' not to recreate | ||||
|         # them all the time since generating a new key is a relatively | ||||
|         # expensive operation. | ||||
|  | @ -105,7 +114,7 @@ class TelegramBareClient: | |||
| 
 | ||||
|         # This member will process updates if enabled. | ||||
|         # One may change self.updates.enabled at any later point. | ||||
|         self.updates = UpdateState(workers=None) | ||||
|         self.updates = UpdateState(self._loop) | ||||
| 
 | ||||
|         # Used on connection - the user may modify these and reconnect | ||||
|         kwargs['app_version'] = kwargs.get('app_version', self.__version__) | ||||
|  | @ -129,10 +138,11 @@ class TelegramBareClient: | |||
|         # Uploaded files cache so subsequent calls are instant | ||||
|         self._upload_cache = {} | ||||
| 
 | ||||
|         # Default PingRequest delay | ||||
|         self._last_ping = datetime.now() | ||||
|         self._ping_delay = timedelta(minutes=1) | ||||
|         self._recv_loop = None | ||||
|         self._ping_loop = None | ||||
| 
 | ||||
|         # Default PingRequest delay | ||||
|         self._ping_delay = timedelta(minutes=1) | ||||
| 
 | ||||
|     # endregion | ||||
| 
 | ||||
|  | @ -167,6 +177,7 @@ class TelegramBareClient: | |||
|                     self.session.auth_key, self.session.time_offset = \ | ||||
|                         await authenticator.do_authentication(self._sender.connection) | ||||
|                 except BrokenAuthKeyError: | ||||
|                     self._user_connected = False | ||||
|                     return False | ||||
| 
 | ||||
|                 self.session.layer = LAYER | ||||
|  | @ -198,12 +209,12 @@ class TelegramBareClient: | |||
|             # another data center and this would raise UserMigrateError) | ||||
|             # to also assert whether the user is logged in or not. | ||||
|             self._user_connected = True | ||||
|             if _sync_updates and not _cdn: | ||||
|             if _sync_updates and not _cdn and not self._authorized: | ||||
|                 try: | ||||
|                     await self.sync_updates() | ||||
|                     self._set_connected_and_authorized() | ||||
|                 except UnauthorizedError: | ||||
|                     self._authorized = False | ||||
|                     pass | ||||
| 
 | ||||
|             return True | ||||
| 
 | ||||
|  | @ -211,7 +222,7 @@ class TelegramBareClient: | |||
|             # This is fine, probably layer migration | ||||
|             self._logger.debug('Found invalid item, probably migrating', e) | ||||
|             self.disconnect() | ||||
|             return self.connect( | ||||
|             return await self.connect( | ||||
|                 _exported_auth=_exported_auth, | ||||
|                 _sync_updates=_sync_updates, | ||||
|                 _cdn=_cdn | ||||
|  | @ -261,7 +272,17 @@ class TelegramBareClient: | |||
|         """ | ||||
|         if new_dc is None: | ||||
|             # Assume we are disconnected due to some error, so connect again | ||||
|             return await self.connect() | ||||
|             try: | ||||
|                 await self._reconnect_lock.acquire() | ||||
|                 # Another thread may have connected again, so check that first | ||||
|                 if self.is_connected(): | ||||
|                     return True | ||||
| 
 | ||||
|                 return await self.connect() | ||||
|             except ConnectionResetError: | ||||
|                 return False | ||||
|             finally: | ||||
|                 self._reconnect_lock.release() | ||||
|         else: | ||||
|             self.disconnect() | ||||
|             self.session.auth_key = None  # Force creating new auth_key | ||||
|  | @ -337,7 +358,8 @@ class TelegramBareClient: | |||
|         client = TelegramBareClient( | ||||
|             session, self.api_id, self.api_hash, | ||||
|             proxy=self._sender.connection.conn.proxy, | ||||
|             timeout=self._sender.connection.get_timeout() | ||||
|             timeout=self._sender.connection.get_timeout(), | ||||
|             loop=self._loop | ||||
|         ) | ||||
|         await client.connect(_exported_auth=export_auth, _sync_updates=False) | ||||
|         client._authorized = True  # We exported the auth, so we got auth | ||||
|  | @ -356,7 +378,8 @@ class TelegramBareClient: | |||
|         client = TelegramBareClient( | ||||
|             session, self.api_id, self.api_hash, | ||||
|             proxy=self._sender.connection.conn.proxy, | ||||
|             timeout=self._sender.connection.get_timeout() | ||||
|             timeout=self._sender.connection.get_timeout(), | ||||
|             loop=self._loop | ||||
|         ) | ||||
| 
 | ||||
|         # This will make use of the new RSA keys for this specific CDN. | ||||
|  | @ -381,55 +404,52 @@ class TelegramBareClient: | |||
|                    x.content_related for x in requests): | ||||
|             raise ValueError('You can only invoke requests, not types!') | ||||
| 
 | ||||
|         # TODO Determine the sender to be used (main or a new connection) | ||||
|         sender = self._sender  # .clone(), .connect() | ||||
|         # We should call receive from this thread if there's no background | ||||
|         # thread reading or if the server disconnected us and we're trying | ||||
|         # to reconnect. This is because the read thread may either be | ||||
|         # locked also trying to reconnect or we may be said thread already. | ||||
|         call_receive = self._recv_loop is None | ||||
| 
 | ||||
|         try: | ||||
|             for _ in range(retries): | ||||
|                 result = await self._invoke(sender, *requests) | ||||
|                 if result is not None: | ||||
|                     return result | ||||
|         for retry in range(retries): | ||||
|             result = await self._invoke(call_receive, retry, *requests) | ||||
|             if result is not None: | ||||
|                 return result | ||||
| 
 | ||||
|             raise ValueError('Number of retries reached 0.') | ||||
|         finally: | ||||
|             if sender != self._sender: | ||||
|                 sender.disconnect()  # Close temporary connections | ||||
|         return None | ||||
| 
 | ||||
|     # Let people use client.invoke(SomeRequest()) instead client(...) | ||||
|     invoke = __call__ | ||||
| 
 | ||||
|     async def _invoke(self, sender, *requests): | ||||
|     async def _invoke(self, call_receive, retry, *requests): | ||||
|         try: | ||||
|             # Ensure that we start with no previous errors (i.e. resending) | ||||
|             for x in requests: | ||||
|                 x.confirm_received.clear() | ||||
|                 x.rpc_error = None | ||||
| 
 | ||||
|             await sender.send(*requests) | ||||
|             while not all(x.confirm_received.is_set() for x in requests): | ||||
|                 await sender.receive(update_state=self.updates) | ||||
|             await self._sender.send(*requests) | ||||
| 
 | ||||
|         except TimeoutError: | ||||
|             pass  # We will just retry | ||||
|             if not call_receive: | ||||
|                 await asyncio.wait( | ||||
|                     list(map(lambda x: x.confirm_received.wait(), requests)), | ||||
|                     timeout=self._sender.connection.get_timeout(), | ||||
|                     loop=self._loop | ||||
|                 ) | ||||
|             else: | ||||
|                 while not all(x.confirm_received.is_set() for x in requests): | ||||
|                     await self._sender.receive(update_state=self.updates) | ||||
| 
 | ||||
|         except ConnectionResetError: | ||||
|             if not self._user_connected: | ||||
|                 # Only attempt reconnecting if we're authorized | ||||
|             if not self._user_connected or self._reconnect_lock.locked(): | ||||
|                 # Only attempt reconnecting if the user called connect and not | ||||
|                 # reconnecting already. | ||||
|                 raise | ||||
| 
 | ||||
|             self._logger.debug('Server disconnected us. Reconnecting and ' | ||||
|                                'resending request...') | ||||
| 
 | ||||
|             if sender != self._sender: | ||||
|                 # TODO Try reconnecting forever too? | ||||
|                 await sender.connect() | ||||
|             else: | ||||
|                 while self._user_connected and not await self._reconnect(): | ||||
|                     sleep(0.1)  # Retry forever until we can send the request | ||||
| 
 | ||||
|         finally: | ||||
|             if sender != self._sender: | ||||
|                 sender.disconnect() | ||||
|                                'resending request... (%d)' % retry) | ||||
|             await self._reconnect() | ||||
|             if not self._sender.is_connected(): | ||||
|                 await asyncio.sleep(retry + 1, loop=self._loop) | ||||
|             return None | ||||
| 
 | ||||
|         try: | ||||
|             raise next(x.rpc_error for x in requests if x.rpc_error) | ||||
|  | @ -452,7 +472,7 @@ class TelegramBareClient: | |||
|             ) | ||||
| 
 | ||||
|             await self._reconnect(new_dc=e.new_dc) | ||||
|             return await self._invoke(sender, *requests) | ||||
|             return None | ||||
| 
 | ||||
|         except ServerError as e: | ||||
|             # Telegram is having some issues, just retry | ||||
|  | @ -467,7 +487,8 @@ class TelegramBareClient: | |||
|             self._logger.debug( | ||||
|                 'Sleep of %d seconds below threshold, sleeping' % e.seconds | ||||
|             ) | ||||
|             sleep(e.seconds) | ||||
|             await asyncio.sleep(e.seconds, loop=self._loop) | ||||
|             return None | ||||
| 
 | ||||
|     # Some really basic functionality | ||||
| 
 | ||||
|  | @ -670,16 +691,13 @@ class TelegramBareClient: | |||
|         """ | ||||
|         self.updates.process(await self(GetStateRequest())) | ||||
| 
 | ||||
|     def add_update_handler(self, handler): | ||||
|     async def add_update_handler(self, handler): | ||||
|         """Adds an update handler (a function which takes a TLObject, | ||||
|           an update, as its parameter) and listens for updates""" | ||||
|         if not self.updates.get_workers: | ||||
|             warnings.warn("There are no update workers running, so adding an update handler will have no effect.") | ||||
| 
 | ||||
|         sync = not self.updates.handlers | ||||
|         self.updates.handlers.append(handler) | ||||
|         if sync: | ||||
|             self.sync_updates() | ||||
|             await self.sync_updates() | ||||
| 
 | ||||
|     def remove_update_handler(self, handler): | ||||
|         self.updates.handlers.remove(handler) | ||||
|  | @ -693,6 +711,63 @@ class TelegramBareClient: | |||
| 
 | ||||
|     def _set_connected_and_authorized(self): | ||||
|         self._authorized = True | ||||
|         # TODO self.updates.setup_workers() | ||||
|         if self._recv_loop is None: | ||||
|             self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop) | ||||
|         if self._ping_loop is None: | ||||
|             self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop) | ||||
| 
 | ||||
|     async def _ping_loop_impl(self): | ||||
|         while self._user_connected: | ||||
|             await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True))) | ||||
|             await asyncio.sleep(self._ping_delay.seconds, loop=self._loop) | ||||
|         self._ping_loop = None | ||||
| 
 | ||||
|     async def _recv_loop_impl(self): | ||||
|         need_reconnect = False | ||||
|         timeout = 1 | ||||
|         while self._user_connected: | ||||
|             try: | ||||
|                 if need_reconnect: | ||||
|                     need_reconnect = False | ||||
|                     while self._user_connected and not await self._reconnect(): | ||||
|                         await asyncio.sleep(0.1, loop=self._loop)  # Retry forever, this is instant messaging | ||||
| 
 | ||||
|                 await self._sender.receive(update_state=self.updates) | ||||
|             except TimeoutError: | ||||
|                 # No problem. | ||||
|                 pass | ||||
|             except ConnectionError as error: | ||||
|                 self._logger.debug(error) | ||||
|                 need_reconnect = True | ||||
|                 await asyncio.sleep(min(timeout, 15), loop=self._loop) | ||||
|                 timeout *= 2 | ||||
|             except Exception as error: | ||||
|                 # Unknown exception, pass it to the main thread | ||||
|                 self._logger.debug( | ||||
|                     '[ERROR] Unknown error on the read loop, please report', | ||||
|                     error | ||||
|                 ) | ||||
| 
 | ||||
|                 try: | ||||
|                     import socks | ||||
|                     if isinstance(error, ( | ||||
|                             socks.GeneralProxyError, socks.ProxyConnectionError | ||||
|                     )): | ||||
|                         # This is a known error, and it's not related to | ||||
|                         # Telegram but rather to the proxy. Disconnect and | ||||
|                         # hand it over to the main thread. | ||||
|                         self._background_error = error | ||||
|                         self.disconnect() | ||||
|                         break | ||||
|                 except ImportError: | ||||
|                     "Not using PySocks, so it can't be a socket error" | ||||
| 
 | ||||
|                 # If something strange happens we don't want to enter an | ||||
|                 # infinite loop where all we do is raise an exception, so | ||||
|                 # add a little sleep to avoid the CPU usage going mad. | ||||
|                 await asyncio.sleep(0.1, loop=self._loop) | ||||
|                 break | ||||
|             timeout = 1 | ||||
|         self._recv_loop = None | ||||
| 
 | ||||
|     # endregion | ||||
|  |  | |||
|  | @ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient): | |||
|                  connection_mode=ConnectionMode.TCP_FULL, | ||||
|                  proxy=None, | ||||
|                  timeout=timedelta(seconds=5), | ||||
|                  loop=None, | ||||
|                  **kwargs): | ||||
|         """Initializes the Telegram client with the specified API ID and Hash. | ||||
| 
 | ||||
|  | @ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient): | |||
|             connection_mode=connection_mode, | ||||
|             proxy=proxy, | ||||
|             timeout=timeout, | ||||
|             loop=loop, | ||||
|             **kwargs | ||||
|         ) | ||||
| 
 | ||||
|  | @ -104,8 +106,9 @@ class TelegramClient(TelegramBareClient): | |||
|         """Sends a code request to the specified phone number""" | ||||
|         phone = EntityDatabase.parse_phone(phone) or self._phone | ||||
|         result = await self(SendCodeRequest(phone, self.api_id, self.api_hash)) | ||||
|         self._phone = phone | ||||
|         self._phone_code_hash = result.phone_code_hash | ||||
|         if result: | ||||
|             self._phone = phone | ||||
|             self._phone_code_hash = result.phone_code_hash | ||||
|         return result | ||||
| 
 | ||||
|     async def sign_in(self, phone=None, code=None, | ||||
|  | @ -169,8 +172,10 @@ class TelegramClient(TelegramBareClient): | |||
|                 'and a password only if an RPCError was raised before.' | ||||
|             ) | ||||
| 
 | ||||
|         self._set_connected_and_authorized() | ||||
|         return result.user | ||||
|         if result: | ||||
|             self._set_connected_and_authorized() | ||||
|             return result.user | ||||
|         return result | ||||
| 
 | ||||
|     async def sign_up(self, code, first_name, last_name=''): | ||||
|         """Signs up to Telegram. Make sure you sent a code request first!""" | ||||
|  | @ -182,8 +187,10 @@ class TelegramClient(TelegramBareClient): | |||
|             last_name=last_name | ||||
|         )) | ||||
| 
 | ||||
|         self._set_connected_and_authorized() | ||||
|         return result.user | ||||
|         if result: | ||||
|             self._set_connected_and_authorized() | ||||
|             return result.user | ||||
|         return result | ||||
| 
 | ||||
|     async def log_out(self): | ||||
|         """Logs out and deletes the current session. | ||||
|  | @ -239,7 +246,7 @@ class TelegramClient(TelegramBareClient): | |||
|                 offset_peer=offset_peer, | ||||
|                 limit=need if need < float('inf') else 0 | ||||
|             )) | ||||
|             if not r.dialogs: | ||||
|             if not r or not r.dialogs: | ||||
|                 break | ||||
| 
 | ||||
|             for d in r.dialogs: | ||||
|  | @ -288,10 +295,12 @@ class TelegramClient(TelegramBareClient): | |||
|         :return List[telethon.tl.custom.Draft]: A list of open drafts | ||||
|         """ | ||||
|         response = await self(GetAllDraftsRequest()) | ||||
|         self.session.process_entities(response) | ||||
|         self.session.generate_sequence(response.seq) | ||||
|         drafts = [Draft._from_update(self, u) for u in response.updates] | ||||
|         return drafts | ||||
|         if response: | ||||
|             self.session.process_entities(response) | ||||
|             self.session.generate_sequence(response.seq) | ||||
|             drafts = [Draft._from_update(self, u) for u in response.updates] | ||||
|             return drafts | ||||
|         return response | ||||
| 
 | ||||
|     async def send_message(self, | ||||
|                            entity, | ||||
|  | @ -313,6 +322,9 @@ class TelegramClient(TelegramBareClient): | |||
|             reply_to_msg_id=self._get_reply_to(reply_to) | ||||
|         ) | ||||
|         result = await self(request) | ||||
|         if not result: | ||||
|             return result | ||||
| 
 | ||||
|         if isinstance(result, UpdateShortSentMessage): | ||||
|             return Message( | ||||
|                 id=result.id, | ||||
|  | @ -407,6 +419,8 @@ class TelegramClient(TelegramBareClient): | |||
|             min_id=min_id, | ||||
|             add_offset=add_offset | ||||
|         )) | ||||
|         if not result: | ||||
|             return result | ||||
| 
 | ||||
|         # The result may be a messages slice (not all messages were retrieved) | ||||
|         # or simply a messages TLObject. In the later case, no "count" | ||||
|  |  | |||
|  | @ -11,6 +11,12 @@ class MessageContainer(TLObject): | |||
|         self.content_related = False | ||||
|         self.messages = messages | ||||
| 
 | ||||
|     def to_dict(self, recursive=True): | ||||
|         return { | ||||
|             'content_related': self.content_related, | ||||
|             'messages': self.messages, | ||||
|         } | ||||
| 
 | ||||
|     def to_bytes(self): | ||||
|         return struct.pack( | ||||
|             '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages) | ||||
|  | @ -25,3 +31,9 @@ class MessageContainer(TLObject): | |||
|             inner_sequence = reader.read_int() | ||||
|             inner_length = reader.read_int() | ||||
|             yield inner_msg_id, inner_sequence, inner_length | ||||
| 
 | ||||
|     def __str__(self): | ||||
|         return TLObject.pretty_format(self) | ||||
| 
 | ||||
|     def stringify(self): | ||||
|         return TLObject.pretty_format(self, indent=0) | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| import struct | ||||
| import logging | ||||
| 
 | ||||
| from . import TLObject, GzipPacked | ||||
| 
 | ||||
|  | @ -11,7 +12,23 @@ class TLMessage(TLObject): | |||
|         self.msg_id = session.get_new_msg_id() | ||||
|         self.seq_no = session.generate_sequence(request.content_related) | ||||
|         self.request = request | ||||
|         self.container_msg_id = None | ||||
|         logging.getLogger(__name__).debug(self) | ||||
| 
 | ||||
|     def to_dict(self, recursive=True): | ||||
|         return { | ||||
|             'msg_id': self.msg_id, | ||||
|             'seq_no': self.seq_no, | ||||
|             'request': self.request, | ||||
|             'container_msg_id': self.container_msg_id, | ||||
|         } | ||||
| 
 | ||||
|     def to_bytes(self): | ||||
|         body = GzipPacked.gzip_if_smaller(self.request) | ||||
|         return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body | ||||
| 
 | ||||
|     def __str__(self): | ||||
|         return TLObject.pretty_format(self) | ||||
| 
 | ||||
|     def stringify(self): | ||||
|         return TLObject.pretty_format(self, indent=0) | ||||
|  |  | |||
|  | @ -1,12 +1,10 @@ | |||
| from datetime import datetime | ||||
| from threading import Event | ||||
| 
 | ||||
| 
 | ||||
| class TLObject: | ||||
|     def __init__(self): | ||||
|         self.request_msg_id = 0  # Long | ||||
| 
 | ||||
|         self.confirm_received = Event() | ||||
|         self.confirm_received = None | ||||
|         self.rpc_error = None | ||||
| 
 | ||||
|         # These should be overrode | ||||
|  |  | |||
|  | @ -1,8 +1,8 @@ | |||
| import logging | ||||
| import pickle | ||||
| import asyncio | ||||
| from collections import deque | ||||
| from datetime import datetime | ||||
| from threading import RLock, Event, Thread | ||||
| 
 | ||||
| from .tl import types as tl | ||||
| 
 | ||||
|  | @ -13,177 +13,72 @@ class UpdateState: | |||
|     """ | ||||
|     WORKER_POLL_TIMEOUT = 5.0  # Avoid waiting forever on the workers | ||||
| 
 | ||||
|     def __init__(self, workers=None): | ||||
|         """ | ||||
|         :param workers: This integer parameter has three possible cases: | ||||
|           workers is None: Updates will *not* be stored on self. | ||||
|           workers = 0: Another thread is responsible for calling self.poll() | ||||
|           workers > 0: 'workers' background threads will be spawned, any | ||||
|                        any of them will invoke all the self.handlers. | ||||
|         """ | ||||
|         self._workers = workers | ||||
|         self._worker_threads = [] | ||||
| 
 | ||||
|     def __init__(self, loop=None): | ||||
|         self.handlers = [] | ||||
|         self._updates_lock = RLock() | ||||
|         self._updates_available = Event() | ||||
|         self._updates = deque() | ||||
|         self._latest_updates = deque(maxlen=10) | ||||
|         self._loop = loop if loop else asyncio.get_event_loop() | ||||
| 
 | ||||
|         self._logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|         # https://core.telegram.org/api/updates | ||||
|         self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) | ||||
| 
 | ||||
|     def can_poll(self): | ||||
|         """Returns True if a call to .poll() won't lock""" | ||||
|         return self._updates_available.is_set() | ||||
| 
 | ||||
|     def poll(self, timeout=None): | ||||
|         """Polls an update or blocks until an update object is available. | ||||
|            If 'timeout is not None', it should be a floating point value, | ||||
|            and the method will 'return None' if waiting times out. | ||||
|         """ | ||||
|         if not self._updates_available.wait(timeout=timeout): | ||||
|             return | ||||
| 
 | ||||
|         with self._updates_lock: | ||||
|             if not self._updates_available.is_set(): | ||||
|                 return | ||||
| 
 | ||||
|             update = self._updates.popleft() | ||||
|             if not self._updates: | ||||
|                 self._updates_available.clear() | ||||
| 
 | ||||
|         if isinstance(update, Exception): | ||||
|             raise update  # Some error was set through (surely StopIteration) | ||||
| 
 | ||||
|         return update | ||||
| 
 | ||||
|     def get_workers(self): | ||||
|         return self._workers | ||||
| 
 | ||||
|     def set_workers(self, n): | ||||
|         """Changes the number of workers running. | ||||
|            If 'n is None', clears all pending updates from memory. | ||||
|         """ | ||||
|         self.stop_workers() | ||||
|         self._workers = n | ||||
|         if n is None: | ||||
|             self._updates.clear() | ||||
|         else: | ||||
|             self.setup_workers() | ||||
| 
 | ||||
|     workers = property(fget=get_workers, fset=set_workers) | ||||
| 
 | ||||
|     def stop_workers(self): | ||||
|         """Raises "StopIterationException" on the worker threads to stop them, | ||||
|            and also clears all of them off the list | ||||
|         """ | ||||
|         if self._workers: | ||||
|             with self._updates_lock: | ||||
|                 # Insert at the beginning so the very next poll causes an error | ||||
|                 # on all the worker threads | ||||
|                 # TODO Should this reset the pts and such? | ||||
|                 for _ in range(self._workers): | ||||
|                     self._updates.appendleft(StopIteration()) | ||||
|                 self._updates_available.set() | ||||
| 
 | ||||
|         for t in self._worker_threads: | ||||
|             t.join() | ||||
| 
 | ||||
|         self._worker_threads.clear() | ||||
| 
 | ||||
|     def setup_workers(self): | ||||
|         if self._worker_threads or not self._workers: | ||||
|             # There already are workers, or workers is None or 0. Do nothing. | ||||
|             return | ||||
| 
 | ||||
|         for i in range(self._workers): | ||||
|             thread = Thread( | ||||
|                 target=UpdateState._worker_loop, | ||||
|                 name='UpdateWorker{}'.format(i), | ||||
|                 daemon=True, | ||||
|                 args=(self, i) | ||||
|             ) | ||||
|             self._worker_threads.append(thread) | ||||
|             thread.start() | ||||
| 
 | ||||
|     def _worker_loop(self, wid): | ||||
|         while True: | ||||
|             try: | ||||
|                 update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT) | ||||
|                 # TODO Maybe people can add different handlers per update type | ||||
|                 if update: | ||||
|                     for handler in self.handlers: | ||||
|                         handler(update) | ||||
|             except StopIteration: | ||||
|                 break | ||||
|             except Exception as e: | ||||
|                 # We don't want to crash a worker thread due to any reason | ||||
|                 self._logger.debug( | ||||
|                     '[ERROR] Unhandled exception on worker {}'.format(wid), e | ||||
|                 ) | ||||
|     def handle_update(self, update): | ||||
|         for handler in self.handlers: | ||||
|             asyncio.ensure_future(handler(update), loop=self._loop) | ||||
| 
 | ||||
|     def process(self, update): | ||||
|         """Processes an update object. This method is normally called by | ||||
|            the library itself. | ||||
|         """ | ||||
|         if self._workers is None: | ||||
|             return  # No processing needs to be done if nobody's working | ||||
|         if isinstance(update, tl.updates.State): | ||||
|             self._state = update | ||||
|             return  # Nothing else to be done | ||||
| 
 | ||||
|         with self._updates_lock: | ||||
|             if isinstance(update, tl.updates.State): | ||||
|                 self._state = update | ||||
|                 return  # Nothing else to be done | ||||
|         pts = getattr(update, 'pts', self._state.pts) | ||||
|         if hasattr(update, 'pts') and pts <= self._state.pts: | ||||
|             return  # We already handled this update | ||||
| 
 | ||||
|             pts = getattr(update, 'pts', self._state.pts) | ||||
|             if hasattr(update, 'pts') and pts <= self._state.pts: | ||||
|                 return  # We already handled this update | ||||
|         self._state.pts = pts | ||||
| 
 | ||||
|             self._state.pts = pts | ||||
|         # TODO There must be a better way to handle updates rather than | ||||
|         # keeping a queue with the latest updates only, and handling | ||||
|         # the 'pts' correctly should be enough. However some updates | ||||
|         # like UpdateUserStatus (even inside UpdateShort) will be called | ||||
|         # repeatedly very often if invoking anything inside an update | ||||
|         # handler. TODO Figure out why. | ||||
|         """ | ||||
|         client = TelegramClient('anon', api_id, api_hash, update_workers=1) | ||||
|         client.connect() | ||||
|         def handle(u): | ||||
|             client.get_me() | ||||
|         client.add_update_handler(handle) | ||||
|         input('Enter to exit.') | ||||
|         """ | ||||
|         data = pickle.dumps(update.to_dict()) | ||||
|         if data in self._latest_updates: | ||||
|             return  # Duplicated too | ||||
| 
 | ||||
|             # TODO There must be a better way to handle updates rather than | ||||
|             # keeping a queue with the latest updates only, and handling | ||||
|             # the 'pts' correctly should be enough. However some updates | ||||
|             # like UpdateUserStatus (even inside UpdateShort) will be called | ||||
|             # repeatedly very often if invoking anything inside an update | ||||
|             # handler. TODO Figure out why. | ||||
|             """ | ||||
|             client = TelegramClient('anon', api_id, api_hash, update_workers=1) | ||||
|             client.connect() | ||||
|             def handle(u): | ||||
|                 client.get_me() | ||||
|             client.add_update_handler(handle) | ||||
|             input('Enter to exit.') | ||||
|             """ | ||||
|             data = pickle.dumps(update.to_dict()) | ||||
|             if data in self._latest_updates: | ||||
|                 return  # Duplicated too | ||||
|         self._latest_updates.append(data) | ||||
| 
 | ||||
|             self._latest_updates.append(data) | ||||
|         if type(update).SUBCLASS_OF_ID == 0x8af52aac:  # crc32(b'Updates') | ||||
|             # Expand "Updates" into "Update", and pass these to callbacks. | ||||
|             # Since .users and .chats have already been processed, we | ||||
|             # don't need to care about those either. | ||||
|             if isinstance(update, tl.UpdateShort): | ||||
|                 self.handle_update(update.update) | ||||
| 
 | ||||
|             if type(update).SUBCLASS_OF_ID == 0x8af52aac:  # crc32(b'Updates') | ||||
|                 # Expand "Updates" into "Update", and pass these to callbacks. | ||||
|                 # Since .users and .chats have already been processed, we | ||||
|                 # don't need to care about those either. | ||||
|                 if isinstance(update, tl.UpdateShort): | ||||
|                     self._updates.append(update.update) | ||||
|                     self._updates_available.set() | ||||
|             elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): | ||||
|                 for upd in update.updates: | ||||
|                     self.handle_update(upd) | ||||
| 
 | ||||
|                 elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): | ||||
|                     self._updates.extend(update.updates) | ||||
|                     self._updates_available.set() | ||||
|             elif not isinstance(update, tl.UpdatesTooLong): | ||||
|                 # TODO Handle "Updates too long" | ||||
|                 self.handle_update(update) | ||||
| 
 | ||||
|                 elif not isinstance(update, tl.UpdatesTooLong): | ||||
|                     # TODO Handle "Updates too long" | ||||
|                     self._updates.append(update) | ||||
|                     self._updates_available.set() | ||||
| 
 | ||||
|             elif type(update).SUBCLASS_OF_ID == 0x9f89304e:  # crc32(b'Update') | ||||
|                 self._updates.append(update) | ||||
|                 self._updates_available.set() | ||||
|             else: | ||||
|                 self._logger.debug('Ignoring "update" of type {}'.format( | ||||
|                     type(update).__name__) | ||||
|                 ) | ||||
|         elif type(update).SUBCLASS_OF_ID == 0x9f89304e:  # crc32(b'Update') | ||||
|             self.handle_update(update) | ||||
|         else: | ||||
|             self._logger.debug('Ignoring "update" of type {}'.format( | ||||
|                 type(update).__name__) | ||||
|             ) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user