diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 02c26996..795b714d 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -65,25 +65,26 @@ class TcpClient: def write(self, data): """Writes (sends) the specified bytes to the connected peer""" - # Ensure that only one thread can send data at once - with self._lock: - try: - view = memoryview(data) - total_sent, total = 0, len(data) - while total_sent < total: - try: - sent = self._socket.send(view[total_sent:]) - if sent == 0: - self.close() - raise ConnectionResetError( - 'The server has closed the connection.') - total_sent += sent + # TODO Check whether the code using this has multiple threads calling + # .write() on the very same socket. If so, have two locks, one for + # .write() and another for .read(). + try: + view = memoryview(data) + total_sent, total = 0, len(data) + while total_sent < total: + try: + sent = self._socket.send(view[total_sent:]) + if sent == 0: + self.close() + raise ConnectionResetError( + 'The server has closed the connection.') + total_sent += sent - except BlockingIOError: - time.sleep(self.delay) - except BrokenPipeError: - self.close() - raise + except BlockingIOError: + time.sleep(self.delay) + except BrokenPipeError: + self.close() + raise def read(self, size, timeout=timedelta(seconds=5)): """Reads (receives) a whole block of 'size bytes @@ -95,47 +96,45 @@ class TcpClient: operation. Set to None for no timeout """ - # Ensure that only one thread can receive data at once - with self._lock: - # Ensure it is not cancelled at first, so we can enter the loop - self.cancelled.clear() + # Ensure it is not cancelled at first, so we can enter the loop + self.cancelled.clear() - # Set the starting time so we can - # calculate whether the timeout should fire - start_time = datetime.now() if timeout is not None else None + # Set the starting time so we can + # calculate whether the timeout should fire + start_time = datetime.now() if timeout is not None else None - with BufferedWriter(BytesIO(), buffer_size=size) as buffer: - bytes_left = size - while bytes_left != 0: - # Only do cancel if no data was read yet - # Otherwise, carry on reading and finish - if self.cancelled.is_set() and bytes_left == size: - raise ReadCancelledError() + with BufferedWriter(BytesIO(), buffer_size=size) as buffer: + bytes_left = size + while bytes_left != 0: + # Only do cancel if no data was read yet + # Otherwise, carry on reading and finish + if self.cancelled.is_set() and bytes_left == size: + raise ReadCancelledError() - try: - partial = self._socket.recv(bytes_left) - if len(partial) == 0: - self.close() - raise ConnectionResetError( - 'The server has closed the connection.') + try: + partial = self._socket.recv(bytes_left) + if len(partial) == 0: + self.close() + raise ConnectionResetError( + 'The server has closed the connection.') - buffer.write(partial) - bytes_left -= len(partial) + buffer.write(partial) + bytes_left -= len(partial) - except BlockingIOError as error: - # No data available yet, sleep a bit - time.sleep(self.delay) + except BlockingIOError as error: + # No data available yet, sleep a bit + time.sleep(self.delay) - # Check if the timeout finished - if timeout is not None: - time_passed = datetime.now() - start_time - if time_passed > timeout: - raise TimeoutError( - 'The read operation exceeded the timeout.') from error + # Check if the timeout finished + if timeout is not None: + time_passed = datetime.now() - start_time + if time_passed > timeout: + raise TimeoutError( + 'The read operation exceeded the timeout.') from error - # If everything went fine, return the read bytes - buffer.flush() - return buffer.raw.getvalue() + # If everything went fine, return the read bytes + buffer.flush() + return buffer.raw.getvalue() def cancel_read(self): """Cancels the read operation IF it hasn't yet