mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-10-30 23:47:33 +03:00 
			
		
		
		
	Update handlers works; it also seems stable
This commit is contained in:
		
							parent
							
								
									917665852d
								
							
						
					
					
						commit
						780e0ceddf
					
				|  | @ -5,13 +5,12 @@ import socket | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from io import BytesIO, BufferedWriter | from io import BytesIO, BufferedWriter | ||||||
| 
 | 
 | ||||||
| loop = asyncio.get_event_loop() |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class TcpClient: | class TcpClient: | ||||||
|     def __init__(self, proxy=None, timeout=timedelta(seconds=5)): |     def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None): | ||||||
|         self.proxy = proxy |         self.proxy = proxy | ||||||
|         self._socket = None |         self._socket = None | ||||||
|  |         self._loop = loop if loop else asyncio.get_event_loop() | ||||||
| 
 | 
 | ||||||
|         if isinstance(timeout, timedelta): |         if isinstance(timeout, timedelta): | ||||||
|             self.timeout = timeout.seconds |             self.timeout = timeout.seconds | ||||||
|  | @ -31,7 +30,7 @@ class TcpClient: | ||||||
|             else:  # tuple, list, etc. |             else:  # tuple, list, etc. | ||||||
|                 self._socket.set_proxy(*self.proxy) |                 self._socket.set_proxy(*self.proxy) | ||||||
| 
 | 
 | ||||||
|         self._socket.settimeout(self.timeout) |         self._socket.setblocking(False) | ||||||
| 
 | 
 | ||||||
|     async def connect(self, ip, port): |     async def connect(self, ip, port): | ||||||
|         """Connects to the specified IP and port number. |         """Connects to the specified IP and port number. | ||||||
|  | @ -42,20 +41,27 @@ class TcpClient: | ||||||
|         else: |         else: | ||||||
|             mode, address = socket.AF_INET, (ip, port) |             mode, address = socket.AF_INET, (ip, port) | ||||||
| 
 | 
 | ||||||
|  |         timeout = 1 | ||||||
|         while True: |         while True: | ||||||
|             try: |             try: | ||||||
|                 while not self._socket: |                 if not self._socket: | ||||||
|                     self._recreate_socket(mode) |                     self._recreate_socket(mode) | ||||||
| 
 | 
 | ||||||
|                 await loop.sock_connect(self._socket, address) |                 await self._loop.sock_connect(self._socket, address) | ||||||
|                 break  # Successful connection, stop retrying to connect |                 break  # Successful connection, stop retrying to connect | ||||||
|  |             except ConnectionError: | ||||||
|  |                 self._socket = None | ||||||
|  |                 await asyncio.sleep(min(timeout, 15)) | ||||||
|  |                 timeout *= 2 | ||||||
|             except OSError as e: |             except OSError as e: | ||||||
|                 # There are some errors that we know how to handle, and |                 # There are some errors that we know how to handle, and | ||||||
|                 # the loop will allow us to retry |                 # the loop will allow us to retry | ||||||
|                 if e.errno == errno.EBADF: |                 if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.EINVAL]: | ||||||
|                     # Bad file descriptor, i.e. socket was closed, set it |                     # Bad file descriptor, i.e. socket was closed, set it | ||||||
|                     # to none to recreate it on the next iteration |                     # to none to recreate it on the next iteration | ||||||
|                     self._socket = None |                     self._socket = None | ||||||
|  |                     await asyncio.sleep(min(timeout, 15)) | ||||||
|  |                     timeout *= 2 | ||||||
|                 else: |                 else: | ||||||
|                     raise |                     raise | ||||||
| 
 | 
 | ||||||
|  | @ -81,13 +87,14 @@ class TcpClient: | ||||||
|             raise ConnectionResetError() |             raise ConnectionResetError() | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             await loop.sock_sendall(self._socket, data) |             await asyncio.wait_for(self._loop.sock_sendall(self._socket, data), | ||||||
|         except socket.timeout as e: |                                    timeout=self.timeout, loop=self._loop) | ||||||
|  |         except asyncio.TimeoutError as e: | ||||||
|             raise TimeoutError() from e |             raise TimeoutError() from e | ||||||
|         except BrokenPipeError: |         except BrokenPipeError: | ||||||
|             self._raise_connection_reset() |             self._raise_connection_reset() | ||||||
|         except OSError as e: |         except OSError as e: | ||||||
|             if e.errno == errno.EBADF: |             if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]: | ||||||
|                 self._raise_connection_reset() |                 self._raise_connection_reset() | ||||||
|             else: |             else: | ||||||
|                 raise |                 raise | ||||||
|  | @ -104,11 +111,12 @@ class TcpClient: | ||||||
|             bytes_left = size |             bytes_left = size | ||||||
|             while bytes_left != 0: |             while bytes_left != 0: | ||||||
|                 try: |                 try: | ||||||
|                     partial = await loop.sock_recv(self._socket, bytes_left) |                     partial = await asyncio.wait_for(self._loop.sock_recv(self._socket, bytes_left), | ||||||
|                 except socket.timeout as e: |                                                      timeout=self.timeout, loop=self._loop) | ||||||
|  |                 except asyncio.TimeoutError as e: | ||||||
|                     raise TimeoutError() from e |                     raise TimeoutError() from e | ||||||
|                 except OSError as e: |                 except OSError as e: | ||||||
|                     if e.errno == errno.EBADF or e.errno == errno.ENOTSOCK: |                     if e.errno in [errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN]: | ||||||
|                         self._raise_connection_reset() |                         self._raise_connection_reset() | ||||||
|                     else: |                     else: | ||||||
|                         raise |                         raise | ||||||
|  |  | ||||||
|  | @ -43,13 +43,13 @@ class Connection: | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, mode=ConnectionMode.TCP_FULL, |     def __init__(self, mode=ConnectionMode.TCP_FULL, | ||||||
|                  proxy=None, timeout=timedelta(seconds=5)): |                  proxy=None, timeout=timedelta(seconds=5), loop=None): | ||||||
|         self._mode = mode |         self._mode = mode | ||||||
|         self._send_counter = 0 |         self._send_counter = 0 | ||||||
|         self._aes_encrypt, self._aes_decrypt = None, None |         self._aes_encrypt, self._aes_decrypt = None, None | ||||||
| 
 | 
 | ||||||
|         # TODO Rename "TcpClient" as some sort of generic socket? |         # TODO Rename "TcpClient" as some sort of generic socket? | ||||||
|         self.conn = TcpClient(proxy=proxy, timeout=timeout) |         self.conn = TcpClient(proxy=proxy, timeout=timeout, loop=loop) | ||||||
| 
 | 
 | ||||||
|         # Sending messages |         # Sending messages | ||||||
|         if mode == ConnectionMode.TCP_FULL: |         if mode == ConnectionMode.TCP_FULL: | ||||||
|  | @ -206,7 +206,7 @@ class Connection: | ||||||
|         return await self.conn.read(length) |         return await self.conn.read(length) | ||||||
| 
 | 
 | ||||||
|     async def _read_obfuscated(self, length): |     async def _read_obfuscated(self, length): | ||||||
|         return await self._aes_decrypt.encrypt(self.conn.read(length)) |         return self._aes_decrypt.encrypt(await self.conn.read(length)) | ||||||
| 
 | 
 | ||||||
|     # endregion |     # endregion | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,6 +1,8 @@ | ||||||
| import gzip | import gzip | ||||||
| import logging | import logging | ||||||
| import struct | import struct | ||||||
|  | import asyncio | ||||||
|  | from asyncio import Event | ||||||
| 
 | 
 | ||||||
| from .. import helpers as utils | from .. import helpers as utils | ||||||
| from ..crypto import AES | from ..crypto import AES | ||||||
|  | @ -30,17 +32,15 @@ class MtProtoSender: | ||||||
|                   in parallel, so thread-safety (hence locking) isn't needed. |                   in parallel, so thread-safety (hence locking) isn't needed. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, session, connection): |     def __init__(self, session, connection, loop=None): | ||||||
|         """Creates a new MtProtoSender configured to send messages through |         """Creates a new MtProtoSender configured to send messages through | ||||||
|            'connection' and using the parameters from 'session'. |            'connection' and using the parameters from 'session'. | ||||||
|         """ |         """ | ||||||
|         self.session = session |         self.session = session | ||||||
|         self.connection = connection |         self.connection = connection | ||||||
|  |         self._loop = loop if loop else asyncio.get_event_loop() | ||||||
|         self._logger = logging.getLogger(__name__) |         self._logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|         # Message IDs that need confirmation |  | ||||||
|         self._need_confirmation = [] |  | ||||||
| 
 |  | ||||||
|         # Requests (as msg_id: Message) sent waiting to be received |         # Requests (as msg_id: Message) sent waiting to be received | ||||||
|         self._pending_receive = {} |         self._pending_receive = {} | ||||||
| 
 | 
 | ||||||
|  | @ -54,12 +54,11 @@ class MtProtoSender: | ||||||
|     def disconnect(self): |     def disconnect(self): | ||||||
|         """Disconnects from the server""" |         """Disconnects from the server""" | ||||||
|         self.connection.close() |         self.connection.close() | ||||||
|         self._need_confirmation.clear() |  | ||||||
|         self._clear_all_pending() |         self._clear_all_pending() | ||||||
| 
 | 
 | ||||||
|     def clone(self): |     def clone(self): | ||||||
|         """Creates a copy of this MtProtoSender as a new connection""" |         """Creates a copy of this MtProtoSender as a new connection""" | ||||||
|         return MtProtoSender(self.session, self.connection.clone()) |         return MtProtoSender(self.session, self.connection.clone(), self._loop) | ||||||
| 
 | 
 | ||||||
|     # region Send and receive |     # region Send and receive | ||||||
| 
 | 
 | ||||||
|  | @ -67,21 +66,23 @@ class MtProtoSender: | ||||||
|         """Sends the specified MTProtoRequest, previously sending any message |         """Sends the specified MTProtoRequest, previously sending any message | ||||||
|            which needed confirmation.""" |            which needed confirmation.""" | ||||||
| 
 | 
 | ||||||
|  |         # Prepare the event of every request | ||||||
|  |         for r in requests: | ||||||
|  |             if r.confirm_received is None: | ||||||
|  |                 r.confirm_received = Event(loop=self._loop) | ||||||
|  |             else: | ||||||
|  |                 r.confirm_received.clear() | ||||||
|  | 
 | ||||||
|         # Finally send our packed request(s) |         # Finally send our packed request(s) | ||||||
|         messages = [TLMessage(self.session, r) for r in requests] |         messages = [TLMessage(self.session, r) for r in requests] | ||||||
|         self._pending_receive.update({m.msg_id: m for m in messages}) |         self._pending_receive.update({m.msg_id: m for m in messages}) | ||||||
| 
 | 
 | ||||||
|         # Pack everything in the same container if we need to send AckRequests |  | ||||||
|         if self._need_confirmation: |  | ||||||
|             messages.append( |  | ||||||
|                 TLMessage(self.session, MsgsAck(self._need_confirmation)) |  | ||||||
|             ) |  | ||||||
|             self._need_confirmation.clear() |  | ||||||
| 
 |  | ||||||
|         if len(messages) == 1: |         if len(messages) == 1: | ||||||
|             message = messages[0] |             message = messages[0] | ||||||
|         else: |         else: | ||||||
|             message = TLMessage(self.session, MessageContainer(messages)) |             message = TLMessage(self.session, MessageContainer(messages)) | ||||||
|  |             for m in messages: | ||||||
|  |                 m.container_msg_id = message.msg_id | ||||||
| 
 | 
 | ||||||
|         await self._send_message(message) |         await self._send_message(message) | ||||||
| 
 | 
 | ||||||
|  | @ -115,6 +116,7 @@ class MtProtoSender: | ||||||
|         message, remote_msg_id, remote_seq = self._decode_msg(body) |         message, remote_msg_id, remote_seq = self._decode_msg(body) | ||||||
|         with BinaryReader(message) as reader: |         with BinaryReader(message) as reader: | ||||||
|             await self._process_msg(remote_msg_id, remote_seq, reader, update_state) |             await self._process_msg(remote_msg_id, remote_seq, reader, update_state) | ||||||
|  |             await self._send_acknowledge(remote_msg_id) | ||||||
| 
 | 
 | ||||||
|     # endregion |     # endregion | ||||||
| 
 | 
 | ||||||
|  | @ -174,7 +176,6 @@ class MtProtoSender: | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         # TODO Check salt, session_id and sequence_number |         # TODO Check salt, session_id and sequence_number | ||||||
|         self._need_confirmation.append(msg_id) |  | ||||||
| 
 | 
 | ||||||
|         code = reader.read_int(signed=False) |         code = reader.read_int(signed=False) | ||||||
|         reader.seek(-4) |         reader.seek(-4) | ||||||
|  | @ -210,14 +211,14 @@ class MtProtoSender: | ||||||
|         if code == MsgsAck.CONSTRUCTOR_ID:  # may handle the request we wanted |         if code == MsgsAck.CONSTRUCTOR_ID:  # may handle the request we wanted | ||||||
|             ack = reader.tgread_object() |             ack = reader.tgread_object() | ||||||
|             assert isinstance(ack, MsgsAck) |             assert isinstance(ack, MsgsAck) | ||||||
|             # Ignore every ack request *unless* when logging out, when it's |             # Ignore every ack request *unless* when logging out, | ||||||
|             # when it seems to only make sense. We also need to set a non-None |             # when it seems to only make sense. We also need to set a non-None | ||||||
|             # result since Telegram doesn't send the response for these. |             # result since Telegram doesn't send the response for these. | ||||||
|             for msg_id in ack.msg_ids: |             for msg_id in ack.msg_ids: | ||||||
|                 r = self._pop_request_of_type(msg_id, LogOutRequest) |                 r = self._pop_request_of_type(msg_id, LogOutRequest) | ||||||
|                 if r: |                 if r: | ||||||
|                     r.result = True  # Telegram won't send this value |                     r.result = True  # Telegram won't send this value | ||||||
|                     r.confirm_received() |                     r.confirm_received.set() | ||||||
|                     self._logger.debug('Message ack confirmed', r) |                     self._logger.debug('Message ack confirmed', r) | ||||||
| 
 | 
 | ||||||
|             return True |             return True | ||||||
|  | @ -259,11 +260,29 @@ class MtProtoSender: | ||||||
|         if message and isinstance(message.request, t): |         if message and isinstance(message.request, t): | ||||||
|             return self._pending_receive.pop(msg_id).request |             return self._pending_receive.pop(msg_id).request | ||||||
| 
 | 
 | ||||||
|  |     def _pop_requests_of_container(self, container_msg_id): | ||||||
|  |         msgs = [msg for msg in self._pending_receive.values() if msg.container_msg_id == container_msg_id] | ||||||
|  |         requests = [msg.request for msg in msgs] | ||||||
|  |         for msg in msgs: | ||||||
|  |             self._pending_receive.pop(msg.msg_id, None) | ||||||
|  |         return requests | ||||||
|  | 
 | ||||||
|     def _clear_all_pending(self): |     def _clear_all_pending(self): | ||||||
|         for r in self._pending_receive.values(): |         for r in self._pending_receive.values(): | ||||||
|             r.confirm_received.set() |             r.request.confirm_received.set() | ||||||
|         self._pending_receive.clear() |         self._pending_receive.clear() | ||||||
| 
 | 
 | ||||||
|  |     async def _resend_request(self, msg_id): | ||||||
|  |         request = self._pop_request(msg_id) | ||||||
|  |         if request: | ||||||
|  |             self._logger.debug('requests is about to resend') | ||||||
|  |             await self.send(request) | ||||||
|  |             return | ||||||
|  |         requests = self._pop_requests_of_container(msg_id) | ||||||
|  |         if requests: | ||||||
|  |             self._logger.debug('container of requests is about to resend') | ||||||
|  |             await self.send(*requests) | ||||||
|  | 
 | ||||||
|     async def _handle_pong(self, msg_id, sequence, reader): |     async def _handle_pong(self, msg_id, sequence, reader): | ||||||
|         self._logger.debug('Handling pong') |         self._logger.debug('Handling pong') | ||||||
|         pong = reader.tgread_object() |         pong = reader.tgread_object() | ||||||
|  | @ -303,10 +322,9 @@ class MtProtoSender: | ||||||
|         self.session.salt = struct.unpack( |         self.session.salt = struct.unpack( | ||||||
|             '<Q', struct.pack('<q', bad_salt.new_server_salt) |             '<Q', struct.pack('<q', bad_salt.new_server_salt) | ||||||
|         )[0] |         )[0] | ||||||
|  |         self.session.save() | ||||||
| 
 | 
 | ||||||
|         request = self._pop_request(bad_salt.bad_msg_id) |         await self._resend_request(bad_salt.bad_msg_id) | ||||||
|         if request: |  | ||||||
|             await self.send(request) |  | ||||||
| 
 | 
 | ||||||
|         return True |         return True | ||||||
| 
 | 
 | ||||||
|  | @ -322,15 +340,18 @@ class MtProtoSender: | ||||||
|             self.session.update_time_offset(correct_msg_id=msg_id) |             self.session.update_time_offset(correct_msg_id=msg_id) | ||||||
|             self._logger.debug('Read Bad Message error: ' + str(error)) |             self._logger.debug('Read Bad Message error: ' + str(error)) | ||||||
|             self._logger.debug('Attempting to use the correct time offset.') |             self._logger.debug('Attempting to use the correct time offset.') | ||||||
|  |             await self._resend_request(bad_msg.bad_msg_id) | ||||||
|             return True |             return True | ||||||
|         elif bad_msg.error_code == 32: |         elif bad_msg.error_code == 32: | ||||||
|             # msg_seqno too low, so just pump it up by some "large" amount |             # msg_seqno too low, so just pump it up by some "large" amount | ||||||
|             # TODO A better fix would be to start with a new fresh session ID |             # TODO A better fix would be to start with a new fresh session ID | ||||||
|             self.session._sequence += 64 |             self.session._sequence += 64 | ||||||
|  |             await self._resend_request(bad_msg.bad_msg_id) | ||||||
|             return True |             return True | ||||||
|         elif bad_msg.error_code == 33: |         elif bad_msg.error_code == 33: | ||||||
|             # msg_seqno too high never seems to happen but just in case |             # msg_seqno too high never seems to happen but just in case | ||||||
|             self.session._sequence -= 16 |             self.session._sequence -= 16 | ||||||
|  |             await self._resend_request(bad_msg.bad_msg_id) | ||||||
|             return True |             return True | ||||||
|         else: |         else: | ||||||
|             raise error |             raise error | ||||||
|  | @ -341,7 +362,6 @@ class MtProtoSender: | ||||||
| 
 | 
 | ||||||
|         # TODO For now, simply ack msg_new.answer_msg_id |         # TODO For now, simply ack msg_new.answer_msg_id | ||||||
|         # Relevant tdesktop source code: https://goo.gl/VvpCC6 |         # Relevant tdesktop source code: https://goo.gl/VvpCC6 | ||||||
|         await self._send_acknowledge(msg_new.answer_msg_id) |  | ||||||
|         return True |         return True | ||||||
| 
 | 
 | ||||||
|     async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader): |     async def _handle_msg_new_detailed_info(self, msg_id, sequence, reader): | ||||||
|  | @ -350,7 +370,6 @@ class MtProtoSender: | ||||||
| 
 | 
 | ||||||
|         # TODO For now, simply ack msg_new.answer_msg_id |         # TODO For now, simply ack msg_new.answer_msg_id | ||||||
|         # Relevant tdesktop source code: https://goo.gl/G7DPsR |         # Relevant tdesktop source code: https://goo.gl/G7DPsR | ||||||
|         await self._send_acknowledge(msg_new.answer_msg_id) |  | ||||||
|         return True |         return True | ||||||
| 
 | 
 | ||||||
|     async def _handle_new_session_created(self, msg_id, sequence, reader): |     async def _handle_new_session_created(self, msg_id, sequence, reader): | ||||||
|  | @ -378,9 +397,6 @@ class MtProtoSender: | ||||||
|                     reader.read_int(), reader.tgread_string() |                     reader.read_int(), reader.tgread_string() | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|             # Acknowledge that we received the error |  | ||||||
|             await self._send_acknowledge(request_id) |  | ||||||
| 
 |  | ||||||
|             if request: |             if request: | ||||||
|                 request.rpc_error = error |                 request.rpc_error = error | ||||||
|                 request.confirm_received.set() |                 request.confirm_received.set() | ||||||
|  |  | ||||||
|  | @ -1,10 +1,10 @@ | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| import warnings | import asyncio | ||||||
| from datetime import timedelta, datetime | from datetime import timedelta, datetime | ||||||
| from hashlib import md5 | from hashlib import md5 | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from time import sleep | from asyncio import Lock | ||||||
| 
 | 
 | ||||||
| from . import helpers as utils | from . import helpers as utils | ||||||
| from .crypto import rsa, CdnDecrypter | from .crypto import rsa, CdnDecrypter | ||||||
|  | @ -17,7 +17,7 @@ from .network import authenticator, MtProtoSender, Connection, ConnectionMode | ||||||
| from .tl import TLObject, Session | from .tl import TLObject, Session | ||||||
| from .tl.all_tlobjects import LAYER | from .tl.all_tlobjects import LAYER | ||||||
| from .tl.functions import ( | from .tl.functions import ( | ||||||
|     InitConnectionRequest, InvokeWithLayerRequest |     InitConnectionRequest, InvokeWithLayerRequest, PingRequest | ||||||
| ) | ) | ||||||
| from .tl.functions.auth import ( | from .tl.functions.auth import ( | ||||||
|     ImportAuthorizationRequest, ExportAuthorizationRequest |     ImportAuthorizationRequest, ExportAuthorizationRequest | ||||||
|  | @ -67,6 +67,7 @@ class TelegramBareClient: | ||||||
|                  connection_mode=ConnectionMode.TCP_FULL, |                  connection_mode=ConnectionMode.TCP_FULL, | ||||||
|                  proxy=None, |                  proxy=None, | ||||||
|                  timeout=timedelta(seconds=5), |                  timeout=timedelta(seconds=5), | ||||||
|  |                  loop=None, | ||||||
|                  **kwargs): |                  **kwargs): | ||||||
|         """Refer to TelegramClient.__init__ for docs on this method""" |         """Refer to TelegramClient.__init__ for docs on this method""" | ||||||
|         if not api_id or not api_hash: |         if not api_id or not api_hash: | ||||||
|  | @ -82,6 +83,8 @@ class TelegramBareClient: | ||||||
|                 'The given session must be a str or a Session instance.' |                 'The given session must be a str or a Session instance.' | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |         self._loop = loop if loop else asyncio.get_event_loop() | ||||||
|  | 
 | ||||||
|         self.session = session |         self.session = session | ||||||
|         self.api_id = int(api_id) |         self.api_id = int(api_id) | ||||||
|         self.api_hash = api_hash |         self.api_hash = api_hash | ||||||
|  | @ -92,12 +95,18 @@ class TelegramBareClient: | ||||||
|         # that calls .connect(). Every other thread will spawn a new |         # that calls .connect(). Every other thread will spawn a new | ||||||
|         # temporary connection. The connection on this one is always |         # temporary connection. The connection on this one is always | ||||||
|         # kept open so Telegram can send us updates. |         # kept open so Telegram can send us updates. | ||||||
|         self._sender = MtProtoSender(self.session, Connection( |         self._sender = MtProtoSender( | ||||||
|             mode=connection_mode, proxy=proxy, timeout=timeout |             self.session, | ||||||
|         )) |             Connection(mode=connection_mode, proxy=proxy, timeout=timeout, loop=self._loop), | ||||||
|  |             self._loop | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         self._logger = logging.getLogger(__name__) |         self._logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|  |         # Two coroutines may be calling reconnect() when the connection is lost, | ||||||
|  |         # we only want one to actually perform the reconnection. | ||||||
|  |         self._reconnect_lock = Lock(loop=self._loop) | ||||||
|  | 
 | ||||||
|         # Cache "exported" sessions as 'dc_id: Session' not to recreate |         # Cache "exported" sessions as 'dc_id: Session' not to recreate | ||||||
|         # them all the time since generating a new key is a relatively |         # them all the time since generating a new key is a relatively | ||||||
|         # expensive operation. |         # expensive operation. | ||||||
|  | @ -105,7 +114,7 @@ class TelegramBareClient: | ||||||
| 
 | 
 | ||||||
|         # This member will process updates if enabled. |         # This member will process updates if enabled. | ||||||
|         # One may change self.updates.enabled at any later point. |         # One may change self.updates.enabled at any later point. | ||||||
|         self.updates = UpdateState(workers=None) |         self.updates = UpdateState(self._loop) | ||||||
| 
 | 
 | ||||||
|         # Used on connection - the user may modify these and reconnect |         # Used on connection - the user may modify these and reconnect | ||||||
|         kwargs['app_version'] = kwargs.get('app_version', self.__version__) |         kwargs['app_version'] = kwargs.get('app_version', self.__version__) | ||||||
|  | @ -129,10 +138,11 @@ class TelegramBareClient: | ||||||
|         # Uploaded files cache so subsequent calls are instant |         # Uploaded files cache so subsequent calls are instant | ||||||
|         self._upload_cache = {} |         self._upload_cache = {} | ||||||
| 
 | 
 | ||||||
|         # Default PingRequest delay |         self._recv_loop = None | ||||||
|         self._last_ping = datetime.now() |         self._ping_loop = None | ||||||
|         self._ping_delay = timedelta(minutes=1) |  | ||||||
| 
 | 
 | ||||||
|  |         # Default PingRequest delay | ||||||
|  |         self._ping_delay = timedelta(minutes=1) | ||||||
| 
 | 
 | ||||||
|     # endregion |     # endregion | ||||||
| 
 | 
 | ||||||
|  | @ -167,6 +177,7 @@ class TelegramBareClient: | ||||||
|                     self.session.auth_key, self.session.time_offset = \ |                     self.session.auth_key, self.session.time_offset = \ | ||||||
|                         await authenticator.do_authentication(self._sender.connection) |                         await authenticator.do_authentication(self._sender.connection) | ||||||
|                 except BrokenAuthKeyError: |                 except BrokenAuthKeyError: | ||||||
|  |                     self._user_connected = False | ||||||
|                     return False |                     return False | ||||||
| 
 | 
 | ||||||
|                 self.session.layer = LAYER |                 self.session.layer = LAYER | ||||||
|  | @ -198,12 +209,12 @@ class TelegramBareClient: | ||||||
|             # another data center and this would raise UserMigrateError) |             # another data center and this would raise UserMigrateError) | ||||||
|             # to also assert whether the user is logged in or not. |             # to also assert whether the user is logged in or not. | ||||||
|             self._user_connected = True |             self._user_connected = True | ||||||
|             if _sync_updates and not _cdn: |             if _sync_updates and not _cdn and not self._authorized: | ||||||
|                 try: |                 try: | ||||||
|                     await self.sync_updates() |                     await self.sync_updates() | ||||||
|                     self._set_connected_and_authorized() |                     self._set_connected_and_authorized() | ||||||
|                 except UnauthorizedError: |                 except UnauthorizedError: | ||||||
|                     self._authorized = False |                     pass | ||||||
| 
 | 
 | ||||||
|             return True |             return True | ||||||
| 
 | 
 | ||||||
|  | @ -211,7 +222,7 @@ class TelegramBareClient: | ||||||
|             # This is fine, probably layer migration |             # This is fine, probably layer migration | ||||||
|             self._logger.debug('Found invalid item, probably migrating', e) |             self._logger.debug('Found invalid item, probably migrating', e) | ||||||
|             self.disconnect() |             self.disconnect() | ||||||
|             return self.connect( |             return await self.connect( | ||||||
|                 _exported_auth=_exported_auth, |                 _exported_auth=_exported_auth, | ||||||
|                 _sync_updates=_sync_updates, |                 _sync_updates=_sync_updates, | ||||||
|                 _cdn=_cdn |                 _cdn=_cdn | ||||||
|  | @ -261,7 +272,17 @@ class TelegramBareClient: | ||||||
|         """ |         """ | ||||||
|         if new_dc is None: |         if new_dc is None: | ||||||
|             # Assume we are disconnected due to some error, so connect again |             # Assume we are disconnected due to some error, so connect again | ||||||
|             return await self.connect() |             try: | ||||||
|  |                 await self._reconnect_lock.acquire() | ||||||
|  |                 # Another thread may have connected again, so check that first | ||||||
|  |                 if self.is_connected(): | ||||||
|  |                     return True | ||||||
|  | 
 | ||||||
|  |                 return await self.connect() | ||||||
|  |             except ConnectionResetError: | ||||||
|  |                 return False | ||||||
|  |             finally: | ||||||
|  |                 self._reconnect_lock.release() | ||||||
|         else: |         else: | ||||||
|             self.disconnect() |             self.disconnect() | ||||||
|             self.session.auth_key = None  # Force creating new auth_key |             self.session.auth_key = None  # Force creating new auth_key | ||||||
|  | @ -337,7 +358,8 @@ class TelegramBareClient: | ||||||
|         client = TelegramBareClient( |         client = TelegramBareClient( | ||||||
|             session, self.api_id, self.api_hash, |             session, self.api_id, self.api_hash, | ||||||
|             proxy=self._sender.connection.conn.proxy, |             proxy=self._sender.connection.conn.proxy, | ||||||
|             timeout=self._sender.connection.get_timeout() |             timeout=self._sender.connection.get_timeout(), | ||||||
|  |             loop=self._loop | ||||||
|         ) |         ) | ||||||
|         await client.connect(_exported_auth=export_auth, _sync_updates=False) |         await client.connect(_exported_auth=export_auth, _sync_updates=False) | ||||||
|         client._authorized = True  # We exported the auth, so we got auth |         client._authorized = True  # We exported the auth, so we got auth | ||||||
|  | @ -356,7 +378,8 @@ class TelegramBareClient: | ||||||
|         client = TelegramBareClient( |         client = TelegramBareClient( | ||||||
|             session, self.api_id, self.api_hash, |             session, self.api_id, self.api_hash, | ||||||
|             proxy=self._sender.connection.conn.proxy, |             proxy=self._sender.connection.conn.proxy, | ||||||
|             timeout=self._sender.connection.get_timeout() |             timeout=self._sender.connection.get_timeout(), | ||||||
|  |             loop=self._loop | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # This will make use of the new RSA keys for this specific CDN. |         # This will make use of the new RSA keys for this specific CDN. | ||||||
|  | @ -381,55 +404,52 @@ class TelegramBareClient: | ||||||
|                    x.content_related for x in requests): |                    x.content_related for x in requests): | ||||||
|             raise ValueError('You can only invoke requests, not types!') |             raise ValueError('You can only invoke requests, not types!') | ||||||
| 
 | 
 | ||||||
|         # TODO Determine the sender to be used (main or a new connection) |         # We should call receive from this thread if there's no background | ||||||
|         sender = self._sender  # .clone(), .connect() |         # thread reading or if the server disconnected us and we're trying | ||||||
|  |         # to reconnect. This is because the read thread may either be | ||||||
|  |         # locked also trying to reconnect or we may be said thread already. | ||||||
|  |         call_receive = self._recv_loop is None | ||||||
| 
 | 
 | ||||||
|         try: |         for retry in range(retries): | ||||||
|             for _ in range(retries): |             result = await self._invoke(call_receive, retry, *requests) | ||||||
|                 result = await self._invoke(sender, *requests) |             if result is not None: | ||||||
|                 if result is not None: |                 return result | ||||||
|                     return result |  | ||||||
| 
 | 
 | ||||||
|             raise ValueError('Number of retries reached 0.') |         return None | ||||||
|         finally: |  | ||||||
|             if sender != self._sender: |  | ||||||
|                 sender.disconnect()  # Close temporary connections |  | ||||||
| 
 | 
 | ||||||
|     # Let people use client.invoke(SomeRequest()) instead client(...) |     # Let people use client.invoke(SomeRequest()) instead client(...) | ||||||
|     invoke = __call__ |     invoke = __call__ | ||||||
| 
 | 
 | ||||||
|     async def _invoke(self, sender, *requests): |     async def _invoke(self, call_receive, retry, *requests): | ||||||
|         try: |         try: | ||||||
|             # Ensure that we start with no previous errors (i.e. resending) |             # Ensure that we start with no previous errors (i.e. resending) | ||||||
|             for x in requests: |             for x in requests: | ||||||
|                 x.confirm_received.clear() |  | ||||||
|                 x.rpc_error = None |                 x.rpc_error = None | ||||||
| 
 | 
 | ||||||
|             await sender.send(*requests) |             await self._sender.send(*requests) | ||||||
|             while not all(x.confirm_received.is_set() for x in requests): |  | ||||||
|                 await sender.receive(update_state=self.updates) |  | ||||||
| 
 | 
 | ||||||
|         except TimeoutError: |             if not call_receive: | ||||||
|             pass  # We will just retry |                 await asyncio.wait( | ||||||
|  |                     list(map(lambda x: x.confirm_received.wait(), requests)), | ||||||
|  |                     timeout=self._sender.connection.get_timeout(), | ||||||
|  |                     loop=self._loop | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 while not all(x.confirm_received.is_set() for x in requests): | ||||||
|  |                     await self._sender.receive(update_state=self.updates) | ||||||
| 
 | 
 | ||||||
|         except ConnectionResetError: |         except ConnectionResetError: | ||||||
|             if not self._user_connected: |             if not self._user_connected or self._reconnect_lock.locked(): | ||||||
|                 # Only attempt reconnecting if we're authorized |                 # Only attempt reconnecting if the user called connect and not | ||||||
|  |                 # reconnecting already. | ||||||
|                 raise |                 raise | ||||||
| 
 | 
 | ||||||
|             self._logger.debug('Server disconnected us. Reconnecting and ' |             self._logger.debug('Server disconnected us. Reconnecting and ' | ||||||
|                                'resending request...') |                                'resending request... (%d)' % retry) | ||||||
| 
 |             await self._reconnect() | ||||||
|             if sender != self._sender: |             if not self._sender.is_connected(): | ||||||
|                 # TODO Try reconnecting forever too? |                 await asyncio.sleep(retry + 1, loop=self._loop) | ||||||
|                 await sender.connect() |             return None | ||||||
|             else: |  | ||||||
|                 while self._user_connected and not await self._reconnect(): |  | ||||||
|                     sleep(0.1)  # Retry forever until we can send the request |  | ||||||
| 
 |  | ||||||
|         finally: |  | ||||||
|             if sender != self._sender: |  | ||||||
|                 sender.disconnect() |  | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             raise next(x.rpc_error for x in requests if x.rpc_error) |             raise next(x.rpc_error for x in requests if x.rpc_error) | ||||||
|  | @ -452,7 +472,7 @@ class TelegramBareClient: | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             await self._reconnect(new_dc=e.new_dc) |             await self._reconnect(new_dc=e.new_dc) | ||||||
|             return await self._invoke(sender, *requests) |             return None | ||||||
| 
 | 
 | ||||||
|         except ServerError as e: |         except ServerError as e: | ||||||
|             # Telegram is having some issues, just retry |             # Telegram is having some issues, just retry | ||||||
|  | @ -467,7 +487,8 @@ class TelegramBareClient: | ||||||
|             self._logger.debug( |             self._logger.debug( | ||||||
|                 'Sleep of %d seconds below threshold, sleeping' % e.seconds |                 'Sleep of %d seconds below threshold, sleeping' % e.seconds | ||||||
|             ) |             ) | ||||||
|             sleep(e.seconds) |             await asyncio.sleep(e.seconds, loop=self._loop) | ||||||
|  |             return None | ||||||
| 
 | 
 | ||||||
|     # Some really basic functionality |     # Some really basic functionality | ||||||
| 
 | 
 | ||||||
|  | @ -670,16 +691,13 @@ class TelegramBareClient: | ||||||
|         """ |         """ | ||||||
|         self.updates.process(await self(GetStateRequest())) |         self.updates.process(await self(GetStateRequest())) | ||||||
| 
 | 
 | ||||||
|     def add_update_handler(self, handler): |     async def add_update_handler(self, handler): | ||||||
|         """Adds an update handler (a function which takes a TLObject, |         """Adds an update handler (a function which takes a TLObject, | ||||||
|           an update, as its parameter) and listens for updates""" |           an update, as its parameter) and listens for updates""" | ||||||
|         if not self.updates.get_workers: |  | ||||||
|             warnings.warn("There are no update workers running, so adding an update handler will have no effect.") |  | ||||||
| 
 |  | ||||||
|         sync = not self.updates.handlers |         sync = not self.updates.handlers | ||||||
|         self.updates.handlers.append(handler) |         self.updates.handlers.append(handler) | ||||||
|         if sync: |         if sync: | ||||||
|             self.sync_updates() |             await self.sync_updates() | ||||||
| 
 | 
 | ||||||
|     def remove_update_handler(self, handler): |     def remove_update_handler(self, handler): | ||||||
|         self.updates.handlers.remove(handler) |         self.updates.handlers.remove(handler) | ||||||
|  | @ -693,6 +711,63 @@ class TelegramBareClient: | ||||||
| 
 | 
 | ||||||
|     def _set_connected_and_authorized(self): |     def _set_connected_and_authorized(self): | ||||||
|         self._authorized = True |         self._authorized = True | ||||||
|         # TODO self.updates.setup_workers() |         if self._recv_loop is None: | ||||||
|  |             self._recv_loop = asyncio.ensure_future(self._recv_loop_impl(), loop=self._loop) | ||||||
|  |         if self._ping_loop is None: | ||||||
|  |             self._ping_loop = asyncio.ensure_future(self._ping_loop_impl(), loop=self._loop) | ||||||
|  | 
 | ||||||
|  |     async def _ping_loop_impl(self): | ||||||
|  |         while self._user_connected: | ||||||
|  |             await self(PingRequest(int.from_bytes(os.urandom(8), 'big', signed=True))) | ||||||
|  |             await asyncio.sleep(self._ping_delay.seconds, loop=self._loop) | ||||||
|  |         self._ping_loop = None | ||||||
|  | 
 | ||||||
|  |     async def _recv_loop_impl(self): | ||||||
|  |         need_reconnect = False | ||||||
|  |         timeout = 1 | ||||||
|  |         while self._user_connected: | ||||||
|  |             try: | ||||||
|  |                 if need_reconnect: | ||||||
|  |                     need_reconnect = False | ||||||
|  |                     while self._user_connected and not await self._reconnect(): | ||||||
|  |                         await asyncio.sleep(0.1, loop=self._loop)  # Retry forever, this is instant messaging | ||||||
|  | 
 | ||||||
|  |                 await self._sender.receive(update_state=self.updates) | ||||||
|  |             except TimeoutError: | ||||||
|  |                 # No problem. | ||||||
|  |                 pass | ||||||
|  |             except ConnectionError as error: | ||||||
|  |                 self._logger.debug(error) | ||||||
|  |                 need_reconnect = True | ||||||
|  |                 await asyncio.sleep(min(timeout, 15), loop=self._loop) | ||||||
|  |                 timeout *= 2 | ||||||
|  |             except Exception as error: | ||||||
|  |                 # Unknown exception, pass it to the main thread | ||||||
|  |                 self._logger.debug( | ||||||
|  |                     '[ERROR] Unknown error on the read loop, please report', | ||||||
|  |                     error | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |                 try: | ||||||
|  |                     import socks | ||||||
|  |                     if isinstance(error, ( | ||||||
|  |                             socks.GeneralProxyError, socks.ProxyConnectionError | ||||||
|  |                     )): | ||||||
|  |                         # This is a known error, and it's not related to | ||||||
|  |                         # Telegram but rather to the proxy. Disconnect and | ||||||
|  |                         # hand it over to the main thread. | ||||||
|  |                         self._background_error = error | ||||||
|  |                         self.disconnect() | ||||||
|  |                         break | ||||||
|  |                 except ImportError: | ||||||
|  |                     "Not using PySocks, so it can't be a socket error" | ||||||
|  | 
 | ||||||
|  |                 # If something strange happens we don't want to enter an | ||||||
|  |                 # infinite loop where all we do is raise an exception, so | ||||||
|  |                 # add a little sleep to avoid the CPU usage going mad. | ||||||
|  |                 await asyncio.sleep(0.1, loop=self._loop) | ||||||
|  |                 break | ||||||
|  |             timeout = 1 | ||||||
|  |         self._recv_loop = None | ||||||
| 
 | 
 | ||||||
|     # endregion |     # endregion | ||||||
|  |  | ||||||
|  | @ -61,6 +61,7 @@ class TelegramClient(TelegramBareClient): | ||||||
|                  connection_mode=ConnectionMode.TCP_FULL, |                  connection_mode=ConnectionMode.TCP_FULL, | ||||||
|                  proxy=None, |                  proxy=None, | ||||||
|                  timeout=timedelta(seconds=5), |                  timeout=timedelta(seconds=5), | ||||||
|  |                  loop=None, | ||||||
|                  **kwargs): |                  **kwargs): | ||||||
|         """Initializes the Telegram client with the specified API ID and Hash. |         """Initializes the Telegram client with the specified API ID and Hash. | ||||||
| 
 | 
 | ||||||
|  | @ -87,6 +88,7 @@ class TelegramClient(TelegramBareClient): | ||||||
|             connection_mode=connection_mode, |             connection_mode=connection_mode, | ||||||
|             proxy=proxy, |             proxy=proxy, | ||||||
|             timeout=timeout, |             timeout=timeout, | ||||||
|  |             loop=loop, | ||||||
|             **kwargs |             **kwargs | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  | @ -104,8 +106,9 @@ class TelegramClient(TelegramBareClient): | ||||||
|         """Sends a code request to the specified phone number""" |         """Sends a code request to the specified phone number""" | ||||||
|         phone = EntityDatabase.parse_phone(phone) or self._phone |         phone = EntityDatabase.parse_phone(phone) or self._phone | ||||||
|         result = await self(SendCodeRequest(phone, self.api_id, self.api_hash)) |         result = await self(SendCodeRequest(phone, self.api_id, self.api_hash)) | ||||||
|         self._phone = phone |         if result: | ||||||
|         self._phone_code_hash = result.phone_code_hash |             self._phone = phone | ||||||
|  |             self._phone_code_hash = result.phone_code_hash | ||||||
|         return result |         return result | ||||||
| 
 | 
 | ||||||
|     async def sign_in(self, phone=None, code=None, |     async def sign_in(self, phone=None, code=None, | ||||||
|  | @ -169,8 +172,10 @@ class TelegramClient(TelegramBareClient): | ||||||
|                 'and a password only if an RPCError was raised before.' |                 'and a password only if an RPCError was raised before.' | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         self._set_connected_and_authorized() |         if result: | ||||||
|         return result.user |             self._set_connected_and_authorized() | ||||||
|  |             return result.user | ||||||
|  |         return result | ||||||
| 
 | 
 | ||||||
|     async def sign_up(self, code, first_name, last_name=''): |     async def sign_up(self, code, first_name, last_name=''): | ||||||
|         """Signs up to Telegram. Make sure you sent a code request first!""" |         """Signs up to Telegram. Make sure you sent a code request first!""" | ||||||
|  | @ -182,8 +187,10 @@ class TelegramClient(TelegramBareClient): | ||||||
|             last_name=last_name |             last_name=last_name | ||||||
|         )) |         )) | ||||||
| 
 | 
 | ||||||
|         self._set_connected_and_authorized() |         if result: | ||||||
|         return result.user |             self._set_connected_and_authorized() | ||||||
|  |             return result.user | ||||||
|  |         return result | ||||||
| 
 | 
 | ||||||
|     async def log_out(self): |     async def log_out(self): | ||||||
|         """Logs out and deletes the current session. |         """Logs out and deletes the current session. | ||||||
|  | @ -239,7 +246,7 @@ class TelegramClient(TelegramBareClient): | ||||||
|                 offset_peer=offset_peer, |                 offset_peer=offset_peer, | ||||||
|                 limit=need if need < float('inf') else 0 |                 limit=need if need < float('inf') else 0 | ||||||
|             )) |             )) | ||||||
|             if not r.dialogs: |             if not r or not r.dialogs: | ||||||
|                 break |                 break | ||||||
| 
 | 
 | ||||||
|             for d in r.dialogs: |             for d in r.dialogs: | ||||||
|  | @ -288,10 +295,12 @@ class TelegramClient(TelegramBareClient): | ||||||
|         :return List[telethon.tl.custom.Draft]: A list of open drafts |         :return List[telethon.tl.custom.Draft]: A list of open drafts | ||||||
|         """ |         """ | ||||||
|         response = await self(GetAllDraftsRequest()) |         response = await self(GetAllDraftsRequest()) | ||||||
|         self.session.process_entities(response) |         if response: | ||||||
|         self.session.generate_sequence(response.seq) |             self.session.process_entities(response) | ||||||
|         drafts = [Draft._from_update(self, u) for u in response.updates] |             self.session.generate_sequence(response.seq) | ||||||
|         return drafts |             drafts = [Draft._from_update(self, u) for u in response.updates] | ||||||
|  |             return drafts | ||||||
|  |         return response | ||||||
| 
 | 
 | ||||||
|     async def send_message(self, |     async def send_message(self, | ||||||
|                            entity, |                            entity, | ||||||
|  | @ -313,6 +322,9 @@ class TelegramClient(TelegramBareClient): | ||||||
|             reply_to_msg_id=self._get_reply_to(reply_to) |             reply_to_msg_id=self._get_reply_to(reply_to) | ||||||
|         ) |         ) | ||||||
|         result = await self(request) |         result = await self(request) | ||||||
|  |         if not result: | ||||||
|  |             return result | ||||||
|  | 
 | ||||||
|         if isinstance(result, UpdateShortSentMessage): |         if isinstance(result, UpdateShortSentMessage): | ||||||
|             return Message( |             return Message( | ||||||
|                 id=result.id, |                 id=result.id, | ||||||
|  | @ -407,6 +419,8 @@ class TelegramClient(TelegramBareClient): | ||||||
|             min_id=min_id, |             min_id=min_id, | ||||||
|             add_offset=add_offset |             add_offset=add_offset | ||||||
|         )) |         )) | ||||||
|  |         if not result: | ||||||
|  |             return result | ||||||
| 
 | 
 | ||||||
|         # The result may be a messages slice (not all messages were retrieved) |         # The result may be a messages slice (not all messages were retrieved) | ||||||
|         # or simply a messages TLObject. In the later case, no "count" |         # or simply a messages TLObject. In the later case, no "count" | ||||||
|  |  | ||||||
|  | @ -11,6 +11,12 @@ class MessageContainer(TLObject): | ||||||
|         self.content_related = False |         self.content_related = False | ||||||
|         self.messages = messages |         self.messages = messages | ||||||
| 
 | 
 | ||||||
|  |     def to_dict(self, recursive=True): | ||||||
|  |         return { | ||||||
|  |             'content_related': self.content_related, | ||||||
|  |             'messages': self.messages, | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|     def to_bytes(self): |     def to_bytes(self): | ||||||
|         return struct.pack( |         return struct.pack( | ||||||
|             '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages) |             '<Ii', MessageContainer.CONSTRUCTOR_ID, len(self.messages) | ||||||
|  | @ -25,3 +31,9 @@ class MessageContainer(TLObject): | ||||||
|             inner_sequence = reader.read_int() |             inner_sequence = reader.read_int() | ||||||
|             inner_length = reader.read_int() |             inner_length = reader.read_int() | ||||||
|             yield inner_msg_id, inner_sequence, inner_length |             yield inner_msg_id, inner_sequence, inner_length | ||||||
|  | 
 | ||||||
|  |     def __str__(self): | ||||||
|  |         return TLObject.pretty_format(self) | ||||||
|  | 
 | ||||||
|  |     def stringify(self): | ||||||
|  |         return TLObject.pretty_format(self, indent=0) | ||||||
|  |  | ||||||
|  | @ -1,4 +1,5 @@ | ||||||
| import struct | import struct | ||||||
|  | import logging | ||||||
| 
 | 
 | ||||||
| from . import TLObject, GzipPacked | from . import TLObject, GzipPacked | ||||||
| 
 | 
 | ||||||
|  | @ -11,7 +12,23 @@ class TLMessage(TLObject): | ||||||
|         self.msg_id = session.get_new_msg_id() |         self.msg_id = session.get_new_msg_id() | ||||||
|         self.seq_no = session.generate_sequence(request.content_related) |         self.seq_no = session.generate_sequence(request.content_related) | ||||||
|         self.request = request |         self.request = request | ||||||
|  |         self.container_msg_id = None | ||||||
|  |         logging.getLogger(__name__).debug(self) | ||||||
|  | 
 | ||||||
|  |     def to_dict(self, recursive=True): | ||||||
|  |         return { | ||||||
|  |             'msg_id': self.msg_id, | ||||||
|  |             'seq_no': self.seq_no, | ||||||
|  |             'request': self.request, | ||||||
|  |             'container_msg_id': self.container_msg_id, | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|     def to_bytes(self): |     def to_bytes(self): | ||||||
|         body = GzipPacked.gzip_if_smaller(self.request) |         body = GzipPacked.gzip_if_smaller(self.request) | ||||||
|         return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body |         return struct.pack('<qii', self.msg_id, self.seq_no, len(body)) + body | ||||||
|  | 
 | ||||||
|  |     def __str__(self): | ||||||
|  |         return TLObject.pretty_format(self) | ||||||
|  | 
 | ||||||
|  |     def stringify(self): | ||||||
|  |         return TLObject.pretty_format(self, indent=0) | ||||||
|  |  | ||||||
|  | @ -1,12 +1,10 @@ | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from threading import Event |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TLObject: | class TLObject: | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.request_msg_id = 0  # Long |         self.request_msg_id = 0  # Long | ||||||
| 
 |         self.confirm_received = None | ||||||
|         self.confirm_received = Event() |  | ||||||
|         self.rpc_error = None |         self.rpc_error = None | ||||||
| 
 | 
 | ||||||
|         # These should be overrode |         # These should be overrode | ||||||
|  |  | ||||||
|  | @ -1,8 +1,8 @@ | ||||||
| import logging | import logging | ||||||
| import pickle | import pickle | ||||||
|  | import asyncio | ||||||
| from collections import deque | from collections import deque | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from threading import RLock, Event, Thread |  | ||||||
| 
 | 
 | ||||||
| from .tl import types as tl | from .tl import types as tl | ||||||
| 
 | 
 | ||||||
|  | @ -13,177 +13,72 @@ class UpdateState: | ||||||
|     """ |     """ | ||||||
|     WORKER_POLL_TIMEOUT = 5.0  # Avoid waiting forever on the workers |     WORKER_POLL_TIMEOUT = 5.0  # Avoid waiting forever on the workers | ||||||
| 
 | 
 | ||||||
|     def __init__(self, workers=None): |     def __init__(self, loop=None): | ||||||
|         """ |  | ||||||
|         :param workers: This integer parameter has three possible cases: |  | ||||||
|           workers is None: Updates will *not* be stored on self. |  | ||||||
|           workers = 0: Another thread is responsible for calling self.poll() |  | ||||||
|           workers > 0: 'workers' background threads will be spawned, any |  | ||||||
|                        any of them will invoke all the self.handlers. |  | ||||||
|         """ |  | ||||||
|         self._workers = workers |  | ||||||
|         self._worker_threads = [] |  | ||||||
| 
 |  | ||||||
|         self.handlers = [] |         self.handlers = [] | ||||||
|         self._updates_lock = RLock() |  | ||||||
|         self._updates_available = Event() |  | ||||||
|         self._updates = deque() |  | ||||||
|         self._latest_updates = deque(maxlen=10) |         self._latest_updates = deque(maxlen=10) | ||||||
|  |         self._loop = loop if loop else asyncio.get_event_loop() | ||||||
| 
 | 
 | ||||||
|         self._logger = logging.getLogger(__name__) |         self._logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|         # https://core.telegram.org/api/updates |         # https://core.telegram.org/api/updates | ||||||
|         self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) |         self._state = tl.updates.State(0, 0, datetime.now(), 0, 0) | ||||||
| 
 | 
 | ||||||
|     def can_poll(self): |     def handle_update(self, update): | ||||||
|         """Returns True if a call to .poll() won't lock""" |         for handler in self.handlers: | ||||||
|         return self._updates_available.is_set() |             asyncio.ensure_future(handler(update), loop=self._loop) | ||||||
| 
 |  | ||||||
|     def poll(self, timeout=None): |  | ||||||
|         """Polls an update or blocks until an update object is available. |  | ||||||
|            If 'timeout is not None', it should be a floating point value, |  | ||||||
|            and the method will 'return None' if waiting times out. |  | ||||||
|         """ |  | ||||||
|         if not self._updates_available.wait(timeout=timeout): |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         with self._updates_lock: |  | ||||||
|             if not self._updates_available.is_set(): |  | ||||||
|                 return |  | ||||||
| 
 |  | ||||||
|             update = self._updates.popleft() |  | ||||||
|             if not self._updates: |  | ||||||
|                 self._updates_available.clear() |  | ||||||
| 
 |  | ||||||
|         if isinstance(update, Exception): |  | ||||||
|             raise update  # Some error was set through (surely StopIteration) |  | ||||||
| 
 |  | ||||||
|         return update |  | ||||||
| 
 |  | ||||||
|     def get_workers(self): |  | ||||||
|         return self._workers |  | ||||||
| 
 |  | ||||||
|     def set_workers(self, n): |  | ||||||
|         """Changes the number of workers running. |  | ||||||
|            If 'n is None', clears all pending updates from memory. |  | ||||||
|         """ |  | ||||||
|         self.stop_workers() |  | ||||||
|         self._workers = n |  | ||||||
|         if n is None: |  | ||||||
|             self._updates.clear() |  | ||||||
|         else: |  | ||||||
|             self.setup_workers() |  | ||||||
| 
 |  | ||||||
|     workers = property(fget=get_workers, fset=set_workers) |  | ||||||
| 
 |  | ||||||
|     def stop_workers(self): |  | ||||||
|         """Raises "StopIterationException" on the worker threads to stop them, |  | ||||||
|            and also clears all of them off the list |  | ||||||
|         """ |  | ||||||
|         if self._workers: |  | ||||||
|             with self._updates_lock: |  | ||||||
|                 # Insert at the beginning so the very next poll causes an error |  | ||||||
|                 # on all the worker threads |  | ||||||
|                 # TODO Should this reset the pts and such? |  | ||||||
|                 for _ in range(self._workers): |  | ||||||
|                     self._updates.appendleft(StopIteration()) |  | ||||||
|                 self._updates_available.set() |  | ||||||
| 
 |  | ||||||
|         for t in self._worker_threads: |  | ||||||
|             t.join() |  | ||||||
| 
 |  | ||||||
|         self._worker_threads.clear() |  | ||||||
| 
 |  | ||||||
|     def setup_workers(self): |  | ||||||
|         if self._worker_threads or not self._workers: |  | ||||||
|             # There already are workers, or workers is None or 0. Do nothing. |  | ||||||
|             return |  | ||||||
| 
 |  | ||||||
|         for i in range(self._workers): |  | ||||||
|             thread = Thread( |  | ||||||
|                 target=UpdateState._worker_loop, |  | ||||||
|                 name='UpdateWorker{}'.format(i), |  | ||||||
|                 daemon=True, |  | ||||||
|                 args=(self, i) |  | ||||||
|             ) |  | ||||||
|             self._worker_threads.append(thread) |  | ||||||
|             thread.start() |  | ||||||
| 
 |  | ||||||
|     def _worker_loop(self, wid): |  | ||||||
|         while True: |  | ||||||
|             try: |  | ||||||
|                 update = self.poll(timeout=UpdateState.WORKER_POLL_TIMEOUT) |  | ||||||
|                 # TODO Maybe people can add different handlers per update type |  | ||||||
|                 if update: |  | ||||||
|                     for handler in self.handlers: |  | ||||||
|                         handler(update) |  | ||||||
|             except StopIteration: |  | ||||||
|                 break |  | ||||||
|             except Exception as e: |  | ||||||
|                 # We don't want to crash a worker thread due to any reason |  | ||||||
|                 self._logger.debug( |  | ||||||
|                     '[ERROR] Unhandled exception on worker {}'.format(wid), e |  | ||||||
|                 ) |  | ||||||
| 
 | 
 | ||||||
|     def process(self, update): |     def process(self, update): | ||||||
|         """Processes an update object. This method is normally called by |         """Processes an update object. This method is normally called by | ||||||
|            the library itself. |            the library itself. | ||||||
|         """ |         """ | ||||||
|         if self._workers is None: |         if isinstance(update, tl.updates.State): | ||||||
|             return  # No processing needs to be done if nobody's working |             self._state = update | ||||||
|  |             return  # Nothing else to be done | ||||||
| 
 | 
 | ||||||
|         with self._updates_lock: |         pts = getattr(update, 'pts', self._state.pts) | ||||||
|             if isinstance(update, tl.updates.State): |         if hasattr(update, 'pts') and pts <= self._state.pts: | ||||||
|                 self._state = update |             return  # We already handled this update | ||||||
|                 return  # Nothing else to be done |  | ||||||
| 
 | 
 | ||||||
|             pts = getattr(update, 'pts', self._state.pts) |         self._state.pts = pts | ||||||
|             if hasattr(update, 'pts') and pts <= self._state.pts: |  | ||||||
|                 return  # We already handled this update |  | ||||||
| 
 | 
 | ||||||
|             self._state.pts = pts |         # TODO There must be a better way to handle updates rather than | ||||||
|  |         # keeping a queue with the latest updates only, and handling | ||||||
|  |         # the 'pts' correctly should be enough. However some updates | ||||||
|  |         # like UpdateUserStatus (even inside UpdateShort) will be called | ||||||
|  |         # repeatedly very often if invoking anything inside an update | ||||||
|  |         # handler. TODO Figure out why. | ||||||
|  |         """ | ||||||
|  |         client = TelegramClient('anon', api_id, api_hash, update_workers=1) | ||||||
|  |         client.connect() | ||||||
|  |         def handle(u): | ||||||
|  |             client.get_me() | ||||||
|  |         client.add_update_handler(handle) | ||||||
|  |         input('Enter to exit.') | ||||||
|  |         """ | ||||||
|  |         data = pickle.dumps(update.to_dict()) | ||||||
|  |         if data in self._latest_updates: | ||||||
|  |             return  # Duplicated too | ||||||
| 
 | 
 | ||||||
|             # TODO There must be a better way to handle updates rather than |         self._latest_updates.append(data) | ||||||
|             # keeping a queue with the latest updates only, and handling |  | ||||||
|             # the 'pts' correctly should be enough. However some updates |  | ||||||
|             # like UpdateUserStatus (even inside UpdateShort) will be called |  | ||||||
|             # repeatedly very often if invoking anything inside an update |  | ||||||
|             # handler. TODO Figure out why. |  | ||||||
|             """ |  | ||||||
|             client = TelegramClient('anon', api_id, api_hash, update_workers=1) |  | ||||||
|             client.connect() |  | ||||||
|             def handle(u): |  | ||||||
|                 client.get_me() |  | ||||||
|             client.add_update_handler(handle) |  | ||||||
|             input('Enter to exit.') |  | ||||||
|             """ |  | ||||||
|             data = pickle.dumps(update.to_dict()) |  | ||||||
|             if data in self._latest_updates: |  | ||||||
|                 return  # Duplicated too |  | ||||||
| 
 | 
 | ||||||
|             self._latest_updates.append(data) |         if type(update).SUBCLASS_OF_ID == 0x8af52aac:  # crc32(b'Updates') | ||||||
|  |             # Expand "Updates" into "Update", and pass these to callbacks. | ||||||
|  |             # Since .users and .chats have already been processed, we | ||||||
|  |             # don't need to care about those either. | ||||||
|  |             if isinstance(update, tl.UpdateShort): | ||||||
|  |                 self.handle_update(update.update) | ||||||
| 
 | 
 | ||||||
|             if type(update).SUBCLASS_OF_ID == 0x8af52aac:  # crc32(b'Updates') |             elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): | ||||||
|                 # Expand "Updates" into "Update", and pass these to callbacks. |                 for upd in update.updates: | ||||||
|                 # Since .users and .chats have already been processed, we |                     self.handle_update(upd) | ||||||
|                 # don't need to care about those either. |  | ||||||
|                 if isinstance(update, tl.UpdateShort): |  | ||||||
|                     self._updates.append(update.update) |  | ||||||
|                     self._updates_available.set() |  | ||||||
| 
 | 
 | ||||||
|                 elif isinstance(update, (tl.Updates, tl.UpdatesCombined)): |             elif not isinstance(update, tl.UpdatesTooLong): | ||||||
|                     self._updates.extend(update.updates) |                 # TODO Handle "Updates too long" | ||||||
|                     self._updates_available.set() |                 self.handle_update(update) | ||||||
| 
 | 
 | ||||||
|                 elif not isinstance(update, tl.UpdatesTooLong): |         elif type(update).SUBCLASS_OF_ID == 0x9f89304e:  # crc32(b'Update') | ||||||
|                     # TODO Handle "Updates too long" |             self.handle_update(update) | ||||||
|                     self._updates.append(update) |         else: | ||||||
|                     self._updates_available.set() |             self._logger.debug('Ignoring "update" of type {}'.format( | ||||||
| 
 |                 type(update).__name__) | ||||||
|             elif type(update).SUBCLASS_OF_ID == 0x9f89304e:  # crc32(b'Update') |             ) | ||||||
|                 self._updates.append(update) |  | ||||||
|                 self._updates_available.set() |  | ||||||
|             else: |  | ||||||
|                 self._logger.debug('Ignoring "update" of type {}'.format( |  | ||||||
|                     type(update).__name__) |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user