diff --git a/telethon/extensions/tcp_client.py b/telethon/extensions/tcp_client.py index 41cec3e5..23268e26 100644 --- a/telethon/extensions/tcp_client.py +++ b/telethon/extensions/tcp_client.py @@ -6,7 +6,6 @@ import logging from datetime import timedelta from io import BytesIO, BufferedWriter -MAX_TIMEOUT = 15 # in seconds CONN_RESET_ERRNOS = { errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH, @@ -24,6 +23,8 @@ class TcpClient: self._socket = None self._loop = loop if loop else asyncio.get_event_loop() self._logger = logging.getLogger(__name__) + self._closed = asyncio.Event(loop=self._loop) + self._closed.set() if isinstance(timeout, timedelta): self.timeout = timeout.seconds @@ -54,32 +55,27 @@ class TcpClient: else: mode, address = socket.AF_INET, (ip, port) - timeout = 1 - while True: - try: - if not self._socket: - self._recreate_socket(mode) + try: + if not self._socket: + self._recreate_socket(mode) - await asyncio.wait_for( - self._loop.sock_connect(self._socket, address), - timeout=self.timeout, - loop=self._loop - ) - break # Successful connection, stop retrying to connect - except asyncio.TimeoutError as e: - raise TimeoutError() from e - except OSError as e: - self._logger.debug('Connect exception: %r' % e) - # ConnectionError + (errno.EBADF, errno.ENOTSOCK, errno.EINVAL) - if e.errno in CONN_RESET_ERRNOS: - self._socket = None - await asyncio.sleep(timeout) - timeout = min(timeout * 2, MAX_TIMEOUT) - else: - raise + await asyncio.wait_for( + self._loop.sock_connect(self._socket, address), + timeout=self.timeout, + loop=self._loop + ) + + self._closed.clear() + except asyncio.TimeoutError as e: + raise TimeoutError() from e + except OSError as e: + if e.errno in CONN_RESET_ERRNOS: + self._raise_connection_reset(e) + else: + raise def _get_connected(self): - return self._socket is not None and self._socket.fileno() >= 0 + return not self._closed.is_set() connected = property(fget=_get_connected) @@ -87,60 +83,59 @@ class TcpClient: """Closes the connection""" try: if self._socket is not None: - self._socket.shutdown(socket.SHUT_RDWR) + if self.connected: + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() except OSError: pass # Ignore ENOTCONN, EBADF, and any other error when closing finally: self._socket = None + self._closed.set() + + async def _wait_close(self, coro): + done, _ = await asyncio.wait( + [coro, self._closed.wait()], + timeout=self.timeout, + return_when=asyncio.FIRST_COMPLETED, + loop=self._loop + ) + if not self.connected: + raise ConnectionResetError('Socket has closed') + if not done: + raise TimeoutError() + return await done.pop() async def write(self, data): """Writes (sends) the specified bytes to the connected peer""" - if self._socket is None: - self._raise_connection_reset() - + if not self.connected: + raise ConnectionResetError('No connection') try: - await asyncio.wait_for( - self.sock_sendall(data), - timeout=self.timeout, - loop=self._loop - ) - except asyncio.TimeoutError as e: - raise TimeoutError() from e + await self._wait_close(self.sock_sendall(data)) except OSError as e: - self._logger.debug('Write exception: %r' % e) if e.errno in CONN_RESET_ERRNOS: - self._raise_connection_reset() + self._raise_connection_reset(e) else: raise async def read(self, size): - """Reads (receives) a whole block of 'size bytes + """Reads (receives) a whole block of size bytes from the connected peer. """ - with BufferedWriter(BytesIO(), buffer_size=size) as buffer: bytes_left = size while bytes_left != 0: + if not self.connected: + raise ConnectionResetError('No connection') try: - if self._socket is None: - self._raise_connection_reset() - partial = await asyncio.wait_for( - self.sock_recv(bytes_left), - timeout=self.timeout, - loop=self._loop - ) - except asyncio.TimeoutError as e: - raise TimeoutError() from e + partial = await self._wait_close(self.sock_recv(bytes_left)) except OSError as e: - self._logger.debug('Read exception: %r' % e) if e.errno in CONN_RESET_ERRNOS: - self._raise_connection_reset() + self._raise_connection_reset(e) else: raise if len(partial) == 0: - self._raise_connection_reset() + self._raise_connection_reset('No data on read') buffer.write(partial) bytes_left -= len(partial) @@ -149,9 +144,12 @@ class TcpClient: buffer.flush() return buffer.raw.getvalue() - def _raise_connection_reset(self): + def _raise_connection_reset(self, error): + description = error if isinstance(error, str) else str(error) + if isinstance(error, str): + error = Exception(error) self.close() # Connection reset -> flag as socket closed - raise ConnectionResetError('The server has closed the connection.') + raise ConnectionResetError(description) from error # due to new https://github.com/python/cpython/pull/4386 def sock_recv(self, n):