diff --git a/errors.py b/errors.py index cdbe4017..c493a1da 100644 --- a/errors.py +++ b/errors.py @@ -1,6 +1,12 @@ import re +class ReadCancelledError(Exception): + """Occurs when a read operation was cancelled""" + def __init__(self): + super().__init__(self, 'You must run `python3 tl_generator.py` first. #ReadTheDocs!') + + class TLGeneratorNotRan(Exception): """Occurs when you should've ran `tl_generator.py`, but you haven't""" def __init__(self): diff --git a/main.py b/main.py index b2a55694..4cc390ec 100755 --- a/main.py +++ b/main.py @@ -78,4 +78,5 @@ if __name__ == '__main__': else: client.send_message(input_peer, msg, markdown=True, no_web_page=True) - print('Thanks for trying the interactive example! Exiting.') + print('Thanks for trying the interactive example! Exiting...') + client.disconnect() diff --git a/network/mtproto_sender.py b/network/mtproto_sender.py index d8a7848c..9e4bfef0 100755 --- a/network/mtproto_sender.py +++ b/network/mtproto_sender.py @@ -3,6 +3,7 @@ import gzip from errors import * from time import sleep +from threading import Thread import utils from crypto import AES @@ -20,13 +21,27 @@ class MtProtoSender: self.need_confirmation = [] # Message IDs that need confirmation self.on_update_handlers = [] + # Set up updates thread + self.updates_thread = Thread(target=self.updates_thread_method, name='Updates thread') + self.updates_thread_running = True + self.updates_thread_paused = True + + self.updates_thread.start() + + def disconnect(self): + """Disconnects and **stops all the running threads**""" + self.updates_thread_running = False + self.transport.cancel_receive() + self.transport.close() + def add_update_handler(self, handler): - """Adds an update handler (a method with one argument, the received TLObject) - that is fired when there are updates available""" + """Adds an update handler (a method with one argument, the received + TLObject) that is fired when there are updates available""" self.on_update_handlers.append(handler) def generate_sequence(self, confirmed): - """Generates the next sequence number, based on whether it was confirmed yet or not""" + """Generates the next sequence number, based on whether it + was confirmed yet or not""" if confirmed: result = self.session.sequence * 2 + 1 self.session.sequence += 1 @@ -37,26 +52,33 @@ class MtProtoSender: # region Send and receive def send(self, request): - """Sends the specified MTProtoRequest, previously sending any message which needed confirmation""" + """Sends the specified MTProtoRequest, previously sending any message + which needed confirmation. This also pauses the updates thread""" - # First check if any message needs confirmation, if this is the case, send an "AckRequest" + # Pause the updates thread: we cannot use self.transport at the same time! + self.pause_updates_thread() + + # If any message needs confirmation send an AckRequest first if self.need_confirmation: - msgs_ack = MsgsAck(self.need_confirmation) - with BinaryWriter() as writer: - msgs_ack.on_send(writer) - self.send_packet(writer.get_bytes(), msgs_ack) - del self.need_confirmation[:] + msgs_ack = MsgsAck(self.need_confirmation[:]) + del self.need_confirmation[:] + self.send(msgs_ack) - # Then send our packed request + # Finally send our packed request with BinaryWriter() as writer: request.on_send(writer) self.send_packet(writer.get_bytes(), request) # And update the saved session self.session.save() + # Don't resume the updates thread yet, + # since every send() is preceded by a receive() def receive(self, request): - """Receives the specified MTProtoRequest ("fills in it" the received data)""" + """Receives the specified MTProtoRequest ("fills in it" + the received data). This also restores the updates thread""" + + # Don't stop receiving until we get the request we wanted while not request.confirm_received: seq, body = self.transport.receive() message, remote_msg_id, remote_sequence = self.decode_msg(body) @@ -64,6 +86,9 @@ class MtProtoSender: with BinaryReader(message) as reader: self.process_msg(remote_msg_id, remote_sequence, reader, request) + # Once we have our request, restore the updates thread + self.restore_updates_thread() + # endregion # region Low level processing @@ -120,31 +145,34 @@ class MtProtoSender: return message, remote_msg_id, remote_sequence - def process_msg(self, msg_id, sequence, reader, request): - """Processes and handles a Telegram message""" + def process_msg(self, msg_id, sequence, reader, request, only_updates=False): + """Processes and handles a Telegram message. Optionally, this + function will only check for updates (hence the request can be None)""" + # TODO Check salt, session_id and sequence_number self.need_confirmation.append(msg_id) code = reader.read_int(signed=False) reader.seek(-4) - # The following codes are "parsed manually" - if code == 0xf35c6d01: # rpc_result - return self.handle_rpc_result(msg_id, sequence, reader, request) - elif code == 0x73f1f8dc: # msg_container - return self.handle_container(msg_id, sequence, reader, request) - elif code == 0x3072cfa1: # gzip_packed - return self.handle_gzip_packed(msg_id, sequence, reader, request) - elif code == 0xedab447b: # bad_server_salt - return self.handle_bad_server_salt(msg_id, sequence, reader, request) - elif code == 0xa7eff811: # bad_msg_notification - return self.handle_bad_msg_notification(msg_id, sequence, reader) - else: - # If the code is not parsed manually, then it was parsed by the code generator! - # In this case, we will simply treat the incoming TLObject as an Update, - # if we can first find a matching TLObject - if code in tlobjects.keys(): - return self.handle_update(msg_id, sequence, reader) + # The following codes are "parsed manually" (and do not refer to an update) + if not only_updates: + if code == 0xf35c6d01: # rpc_result + return self.handle_rpc_result(msg_id, sequence, reader, request) + if code == 0x73f1f8dc: # msg_container + return self.handle_container(msg_id, sequence, reader, request) + if code == 0x3072cfa1: # gzip_packed + return self.handle_gzip_packed(msg_id, sequence, reader, request) + if code == 0xedab447b: # bad_server_salt + return self.handle_bad_server_salt(msg_id, sequence, reader, request) + if code == 0xa7eff811: # bad_msg_notification + return self.handle_bad_msg_notification(msg_id, sequence, reader) + + # If the code is not parsed manually, then it was parsed by the code generator! + # In this case, we will simply treat the incoming TLObject as an Update, + # if we can first find a matching TLObject + if code in tlobjects.keys(): + return self.handle_update(msg_id, sequence, reader) print('Unknown message: {}'.format(hex(code))) return False @@ -196,18 +224,22 @@ class MtProtoSender: error_code = reader.read_int() raise BadMessageError(error_code) - def handle_rpc_result(self, msg_id, sequence, reader, mtproto_request): + def handle_rpc_result(self, msg_id, sequence, reader, request): code = reader.read_int(signed=False) request_id = reader.read_long(signed=False) inner_code = reader.read_int(signed=False) - if request_id == mtproto_request.msg_id: - mtproto_request.confirm_received = True + if not request: + raise ValueError('Cannot handle RPC results if no request was specified. ' + 'This should only happen when the updates thread does not work properly.') + + if request_id == request.msg_id: + request.confirm_received = True if inner_code == 0x2144ca19: # RPC Error error = RPCError(code=reader.read_int(), message=reader.tgread_string()) if error.must_resend: - mtproto_request.confirm_received = False + request.confirm_received = False if error.message.startswith('FLOOD_WAIT_'): print('Should wait {}s. Sleeping until then.'.format(error.additional_data)) @@ -223,11 +255,11 @@ class MtProtoSender: if inner_code == 0x3072cfa1: # GZip packed unpacked_data = gzip.decompress(reader.tgread_bytes()) with BinaryReader(unpacked_data) as compressed_reader: - mtproto_request.on_response(compressed_reader) + request.on_response(compressed_reader) else: reader.seek(-4) - mtproto_request.on_response(reader) + request.on_response(reader) def handle_gzip_packed(self, msg_id, sequence, reader, mtproto_request): code = reader.read_int(signed=False) @@ -238,4 +270,30 @@ class MtProtoSender: self.process_msg(msg_id, sequence, compressed_reader, mtproto_request) # endregion - pass + + def pause_updates_thread(self): + """Pauses the updates thread and sleeps until it's safe to continue""" + if not self.updates_thread_paused: + self.updates_thread_paused = True + self.transport.cancel_receive() + + def restore_updates_thread(self): + """Restores the updates thread""" + self.updates_thread_paused = False + + # TODO avoid, if possible using sleeps; Use thread locks instead + def updates_thread_method(self): + """This method will run until specified and listen for incoming updates""" + while self.updates_thread_running: + if not self.updates_thread_paused: + try: + seq, body = self.transport.receive() + message, remote_msg_id, remote_sequence = self.decode_msg(body) + + with BinaryReader(message) as reader: + self.process_msg(remote_msg_id, remote_sequence, reader, request=None, only_updates=True) + + except ReadCancelledError: + pass + + sleep(self.transport.get_client_delay()) diff --git a/network/tcp_client.py b/network/tcp_client.py index 7287f58a..034662f8 100755 --- a/network/tcp_client.py +++ b/network/tcp_client.py @@ -1,5 +1,8 @@ # Python rough implementation of a C# TCP client import socket +import time + +from errors import ReadCancelledError from utils import BinaryWriter @@ -7,16 +10,20 @@ class TcpClient: def __init__(self): self.connected = False self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.cancelled = False # Has the read operation been cancelled? + self.delay = 0.1 # Read delay when there was no data available def connect(self, ip, port): """Connects to the specified IP and port number""" self.socket.connect((ip, port)) self.connected = True + self.socket.setblocking(False) def close(self): """Closes the connection""" self.socket.close() self.connected = False + self.socket.setblocking(True) def write(self, data): """Writes (sends) the specified bytes to the connected peer""" @@ -24,12 +31,36 @@ class TcpClient: def read(self, buffer_size): """Reads (receives) the specified bytes from the connected peer""" + self.cancelled = False # Ensure it is not cancelled at first + with BinaryWriter() as writer: - while writer.written_count < buffer_size: - # When receiving from the socket, we may not receive all the data at once - # This is why we need to keep checking to make sure that we receive it all - left_count = buffer_size - writer.written_count - partial = self.socket.recv(left_count) - writer.write(partial) + while writer.written_count < buffer_size and not self.cancelled: + try: + # When receiving from the socket, we may not receive all the data at once + # This is why we need to keep checking to make sure that we receive it all + left_count = buffer_size - writer.written_count + partial = self.socket.recv(left_count) + writer.write(partial) + + except BlockingIOError: + # There was no data available for us to read. Sleep a bit + time.sleep(self.delay) + + # If the operation was cancelled *but* data was read, + # this will result on data loss so raise an exception + # TODO this could be solved by using an internal FIFO buffer (first in, first out) + if self.cancelled: + if writer.written_count == 0: + raise ReadCancelledError() + else: + raise NotImplementedError('The read operation was cancelled when some data ' + 'was already read. This has not yet implemented ' + 'an internal buffer, so cannot continue.') return writer.get_bytes() + + def cancel_read(self): + """Cancels the read operation if it was blocking and stops + the current thread until it's cancelled""" + self.cancelled = True + time.sleep(self.delay) diff --git a/network/tcp_transport.py b/network/tcp_transport.py index 82719684..f53aaebe 100755 --- a/network/tcp_transport.py +++ b/network/tcp_transport.py @@ -58,3 +58,12 @@ class TcpTransport: def close(self): if self.tcp_client.connected: self.tcp_client.close() + + def cancel_receive(self): + """Cancels (stops) trying to receive from the remote peer and + stops the current thread until it's cancelled""" + self.tcp_client.cancel_read() + + def get_client_delay(self): + """Gets the client read delay""" + return self.tcp_client.delay diff --git a/tl/__init__.py b/tl/__init__.py index 0cd67cc1..7f1de756 100755 --- a/tl/__init__.py +++ b/tl/__init__.py @@ -1,9 +1,4 @@ -try: - from .all_tlobjects import tlobjects - from .session import Session - from .mtproto_request import MTProtoRequest - from .telegram_client import TelegramClient - -except ImportError: - import errors - raise errors.TLGeneratorNotRan() +from .all_tlobjects import tlobjects +from .session import Session +from .mtproto_request import MTProtoRequest +from .telegram_client import TelegramClient diff --git a/tl/mtproto_request.py b/tl/mtproto_request.py index f642af79..08784d38 100755 --- a/tl/mtproto_request.py +++ b/tl/mtproto_request.py @@ -15,6 +15,7 @@ class MTProtoRequest: self.confirm_received = False # These should be overrode + self.constructor_id = 0 self.confirmed = False self.responded = False diff --git a/tl/telegram_client.py b/tl/telegram_client.py index 241eddc3..d06bc5d3 100644 --- a/tl/telegram_client.py +++ b/tl/telegram_client.py @@ -9,6 +9,7 @@ from errors import * from network import MtProtoSender, TcpTransport from parser.markdown_parser import parse_message_entities +# For sending and receiving requests from tl import Session from tl.types import PeerUser, PeerChat, PeerChannel, InputPeerUser, InputPeerChat, InputPeerChannel, InputPeerEmpty from tl.functions import InvokeWithLayerRequest, InitConnectionRequest @@ -16,6 +17,9 @@ from tl.functions.help import GetConfigRequest from tl.functions.auth import SendCodeRequest, SignInRequest from tl.functions.messages import GetDialogsRequest, GetHistoryRequest, SendMessageRequest +# For working with updates +from tl.types import UpdateShortMessage + class TelegramClient: @@ -90,6 +94,11 @@ class TelegramClient: self.connect(reconnect=True) + def disconnect(self): + """Disconnects from the Telegram server **and pauses all the spawned threads**""" + if self.sender: + self.sender.disconnect() + # endregion # region Telegram requests functions @@ -261,6 +270,10 @@ class TelegramClient: def on_update(self, tlobject): """This method is fired when there are updates from Telegram. Add your own implementation below, or simply override it!""" - print('We have an update: {}'.format(str(tlobject))) + + # Only show incoming messages + if type(tlobject) is UpdateShortMessage: + if not tlobject.out: + print('> User with ID {} said "{}"'.format(tlobject.user_id, tlobject.message)) # endregion diff --git a/tl_generator.py b/tl_generator.py index bbaa2177..fee29d00 100755 --- a/tl_generator.py +++ b/tl_generator.py @@ -54,6 +54,7 @@ def generate_tlobjects(scheme_file): with open(filename, 'w', encoding='utf-8') as file: # Let's build the source code! with SourceBuilder(file) as builder: + # Both types and functions inherit from MTProtoRequest so they all can be sent builder.writeln('from tl.mtproto_request import MTProtoRequest') builder.writeln() builder.writeln() @@ -105,6 +106,9 @@ def generate_tlobjects(scheme_file): builder.writeln('self.result = None') builder.writeln('self.confirmed = True # Confirmed by default') + # Create an attribute that stores the TLObject's constructor ID + builder.writeln('self.constructor_id = {}'.format(hex(tlobject.id))) + # Set the arguments if args: # Leave an empty line if there are any args @@ -115,7 +119,7 @@ def generate_tlobjects(scheme_file): # Write the on_send(self, writer) function builder.writeln('def on_send(self, writer):') - builder.writeln("writer.write_int({}, signed=False) # {}'s constructor ID" + builder.writeln('writer.write_int(self.constructor_id, signed=False)' .format(hex(tlobject.id), tlobject.name)) for arg in tlobject.args: