From 92c9bb12b70f28b11abaaa06dca76b13a40db0fe Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Wed, 29 Aug 2018 11:35:44 +0200 Subject: [PATCH] Use asyncio.open_connection in the TcpClient (cherry picked from commit 573fed1f512831cd9790130cc878655fef2fde98) --- telethon/extensions/tcpclient.py | 126 +++---------------------------- 1 file changed, 12 insertions(+), 114 deletions(-) diff --git a/telethon/extensions/tcpclient.py b/telethon/extensions/tcpclient.py index 1501d295..4baa46ca 100644 --- a/telethon/extensions/tcpclient.py +++ b/telethon/extensions/tcpclient.py @@ -12,7 +12,6 @@ import errno import logging import socket import ssl -from io import BytesIO CONN_RESET_ERRNOS = { errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, @@ -52,6 +51,8 @@ class TcpClient: self.proxy = proxy self.ssl = ssl self._socket = None + self._reader = None + self._writer = None self._closed = asyncio.Event(loop=self._loop) self._closed.set() @@ -111,6 +112,8 @@ class TcpClient: self._socket.setblocking(False) self._closed.clear() + self._reader, self._writer =\ + await asyncio.open_connection(sock=self._socket) except OSError as e: if e.errno in CONN_RESET_ERRNOS: raise ConnectionResetError() from e @@ -126,6 +129,9 @@ class TcpClient: """Closes the connection.""" fd = None try: + if self._writer is not None: + self._writer.close() + if self._socket is not None: fd = self._socket.fileno() if self.is_connected: @@ -135,36 +141,12 @@ class TcpClient: pass # Ignore ENOTCONN, EBADF, and any other error when closing finally: self._socket = None + self._reader = None + self._writer = None self._closed.set() if fd and fd != -1: self._loop.remove_reader(fd) - async def _wait_timeout_or_close(self, coro): - """ - Waits for the given coroutine to complete unless - the socket is closed or `self.timeout` expires. - """ - done, running = await asyncio.wait( - [coro, self._closed.wait()], - timeout=self.timeout, - return_when=asyncio.FIRST_COMPLETED, - loop=self._loop - ) - for r in running: - if not r.cancelled(): - if r.done(): - # Retrieve exception to avoid "not retrieved" errors - r.exception() - - # Cancel the future despite its state - r.cancel() - - if not self.is_connected: - raise self.SocketClosed() - if not done: - raise asyncio.TimeoutError() - return done.pop().result() - async def write(self, data): """ Writes (sends) the specified bytes to the connected peer. @@ -173,13 +155,8 @@ class TcpClient: if not self.is_connected: raise ConnectionResetError('Not connected') - try: - await self._wait_timeout_or_close(self.sock_sendall(data)) - except OSError as e: - if e.errno in CONN_RESET_ERRNOS: - raise ConnectionResetError() from e - else: - raise + self._writer.write(data) + await self._writer.drain() async def read(self, size): """ @@ -191,83 +168,4 @@ class TcpClient: if not self.is_connected: raise ConnectionResetError('Not connected') - with BytesIO() as buffer: - bytes_left = size - while bytes_left != 0: - try: - partial = await self._wait_timeout_or_close( - self.sock_recv(bytes_left) - ) - except asyncio.TimeoutError: - if bytes_left < size: - __log__.warning( - 'Timeout when partial %d/%d had been received', - size - bytes_left, size - ) - raise - except OSError as e: - if e.errno in CONN_RESET_ERRNOS: - raise ConnectionResetError() from e - else: - raise - - if not partial: - raise ConnectionResetError() - - buffer.write(partial) - bytes_left -= len(partial) - - return buffer.getvalue() - - # Due to recent https://github.com/python/cpython/pull/4386 - # Credit to @andr-04 for his original implementation - def sock_recv(self, n): - fut = self._loop.create_future() - self._sock_recv(fut, None, n) - return fut - - def _sock_recv(self, fut, registered_fd, n): - if registered_fd is not None: - self._loop.remove_reader(registered_fd) - if fut.cancelled() or self._socket is None: - return - - try: - data = self._socket.recv(n) - except (BlockingIOError, InterruptedError): - fd = self._socket.fileno() - self._loop.add_reader(fd, self._sock_recv, fut, fd, n) - except Exception as exc: - fut.set_exception(exc) - else: - fut.set_result(data) - - def sock_sendall(self, data): - fut = self._loop.create_future() - if data: - self._sock_sendall(fut, None, data) - else: - fut.set_result(None) - return fut - - def _sock_sendall(self, fut, registered_fd, data): - if registered_fd: - self._loop.remove_writer(registered_fd) - if fut.cancelled() or self._socket is None: - return - - try: - n = self._socket.send(data) - except (BlockingIOError, InterruptedError): - n = 0 - except Exception as exc: - fut.set_exception(exc) - return - - if n == len(data): - fut.set_result(None) - else: - if n: - data = data[n:] - fd = self._socket.fileno() - self._loop.add_writer(fd, self._sock_sendall, fut, fd, data) + return await self._reader.readexactly(size)