mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-04 09:57:29 +03:00 
			
		
		
		
	Apply @andr-04 asyncio commits to TcpClient
This commit is contained in:
		
							parent
							
								
									3ce8b17193
								
							
						
					
					
						commit
						c9ea1bafc0
					
				| 
						 | 
					@ -8,22 +8,37 @@ This class is also not concerned about disconnections or retries of
 | 
				
			||||||
any sort, nor any other kind of errors such as connecting twice.
 | 
					any sort, nor any other kind of errors such as connecting twice.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
 | 
					import errno
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import socket
 | 
					import socket
 | 
				
			||||||
 | 
					from datetime import timedelta
 | 
				
			||||||
from io import BytesIO
 | 
					from io import BytesIO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CONN_RESET_ERRNOS = {
 | 
				
			||||||
 | 
					    errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
 | 
				
			||||||
 | 
					    errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
 | 
				
			||||||
 | 
					    errno.ECONNREFUSED, errno.ECONNRESET, errno.ECONNABORTED,
 | 
				
			||||||
 | 
					    errno.ENETDOWN, errno.ENETRESET, errno.ECONNABORTED,
 | 
				
			||||||
 | 
					    errno.EHOSTDOWN, errno.EPIPE, errno.ESHUTDOWN
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					# catched: EHOSTUNREACH, ECONNREFUSED, ECONNRESET, ENETUNREACH
 | 
				
			||||||
 | 
					# ConnectionError: EPIPE, ESHUTDOWN, ECONNABORTED, ECONNREFUSED, ECONNRESET
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    import socks
 | 
					    import socks
 | 
				
			||||||
except ImportError:
 | 
					except ImportError:
 | 
				
			||||||
    socks = None
 | 
					    socks = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
__log__ = logging.getLogger(__name__)
 | 
					__log__ = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TcpClient:
 | 
					class TcpClient:
 | 
				
			||||||
    """A simple TCP client to ease the work with sockets and proxies."""
 | 
					    """A simple TCP client to ease the work with sockets and proxies."""
 | 
				
			||||||
    def __init__(self, proxy=None, timeout=5):
 | 
					
 | 
				
			||||||
 | 
					    class SocketClosed(ConnectionError):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, proxy=None, timeout=timedelta(seconds=5), loop=None):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Initializes the TCP client.
 | 
					        Initializes the TCP client.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,7 +47,9 @@ class TcpClient:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.proxy = proxy
 | 
					        self.proxy = proxy
 | 
				
			||||||
        self._socket = None
 | 
					        self._socket = None
 | 
				
			||||||
        self._loop = asyncio.get_event_loop()
 | 
					        self._loop = loop or asyncio.get_event_loop()
 | 
				
			||||||
 | 
					        self._closed = asyncio.Event(loop=self._loop)
 | 
				
			||||||
 | 
					        self._closed.set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(timeout, (int, float)):
 | 
					        if isinstance(timeout, (int, float)):
 | 
				
			||||||
            self.timeout = float(timeout)
 | 
					            self.timeout = float(timeout)
 | 
				
			||||||
| 
						 | 
					@ -57,7 +74,7 @@ class TcpClient:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def connect(self, ip, port):
 | 
					    async def connect(self, ip, port):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Tries connecting to IP:port.
 | 
					        Tries connecting to IP:port unless an OSError is raised.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        :param ip: the IP to connect to.
 | 
					        :param ip: the IP to connect to.
 | 
				
			||||||
        :param port: the port to connect to.
 | 
					        :param port: the port to connect to.
 | 
				
			||||||
| 
						 | 
					@ -68,42 +85,78 @@ class TcpClient:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            mode, address = socket.AF_INET, (ip, port)
 | 
					            mode, address = socket.AF_INET, (ip, port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
            if self._socket is None:
 | 
					            if self._socket is None:
 | 
				
			||||||
                self._socket = self._create_socket(mode, self.proxy)
 | 
					                self._socket = self._create_socket(mode, self.proxy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await asyncio.wait_for(self._loop.sock_connect(self._socket, address),
 | 
					            await asyncio.wait_for(
 | 
				
			||||||
                               self.timeout, loop=self._loop)
 | 
					                self._loop.sock_connect(self._socket, address),
 | 
				
			||||||
 | 
					                timeout=self.timeout,
 | 
				
			||||||
 | 
					                loop=self._loop
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self._closed.clear()
 | 
				
			||||||
 | 
					        except asyncio.TimeoutError as e:
 | 
				
			||||||
 | 
					            raise TimeoutError() from e
 | 
				
			||||||
 | 
					        except OSError as e:
 | 
				
			||||||
 | 
					            if e.errno in CONN_RESET_ERRNOS:
 | 
				
			||||||
 | 
					                raise ConnectionResetError() from e
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def is_connected(self):
 | 
					    def is_connected(self):
 | 
				
			||||||
        """Determines whether the client is connected or not."""
 | 
					        """Determines whether the client is connected or not."""
 | 
				
			||||||
        # TODO fileno() is >= 0 even before calling sock_connect!
 | 
					        return not self._closed.is_set()
 | 
				
			||||||
        return self._socket is not None and self._socket.fileno() >= 0
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def close(self):
 | 
					    def close(self):
 | 
				
			||||||
        """Closes the connection."""
 | 
					        """Closes the connection."""
 | 
				
			||||||
        if self._socket is not None:
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
					            if self._socket is not None:
 | 
				
			||||||
 | 
					                if self.is_connected:
 | 
				
			||||||
 | 
					                    self._socket.shutdown(socket.SHUT_RDWR)
 | 
				
			||||||
                self._socket.close()
 | 
					                self._socket.close()
 | 
				
			||||||
        except OSError:
 | 
					        except OSError:
 | 
				
			||||||
                pass
 | 
					            pass  # Ignore ENOTCONN, EBADF, and any other error when closing
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
            self._socket = None
 | 
					            self._socket = None
 | 
				
			||||||
 | 
					            self._closed.set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def _wait_timeout_or_close(self, coro):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Waits for the given coroutine to complete unless
 | 
				
			||||||
 | 
					        the socket is closed or `self.timeout` expires.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        done, running = await asyncio.wait(
 | 
				
			||||||
 | 
					            [coro, self._closed.wait()],
 | 
				
			||||||
 | 
					            timeout=self.timeout,
 | 
				
			||||||
 | 
					            return_when=asyncio.FIRST_COMPLETED,
 | 
				
			||||||
 | 
					            loop=self._loop
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        for r in running:
 | 
				
			||||||
 | 
					            r.cancel()
 | 
				
			||||||
 | 
					        if not self.is_connected:
 | 
				
			||||||
 | 
					            raise self.SocketClosed()
 | 
				
			||||||
 | 
					        if not done:
 | 
				
			||||||
 | 
					            raise TimeoutError()
 | 
				
			||||||
 | 
					        return done.pop().result()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def write(self, data):
 | 
					    async def write(self, data):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Writes (sends) the specified bytes to the connected peer.
 | 
					        Writes (sends) the specified bytes to the connected peer.
 | 
				
			||||||
 | 
					 | 
				
			||||||
        :param data: the data to send.
 | 
					        :param data: the data to send.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if not self.is_connected:
 | 
					        if not self.is_connected:
 | 
				
			||||||
            raise ConnectionError()
 | 
					            raise ConnectionResetError('Not connected')
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
        await asyncio.wait_for(
 | 
					            await self._wait_timeout_or_close(self.sock_sendall(data))
 | 
				
			||||||
            self.sock_sendall(data),
 | 
					        except self.SocketClosed:
 | 
				
			||||||
            timeout=self.timeout,
 | 
					            raise ConnectionResetError('Socket has closed')
 | 
				
			||||||
            loop=self._loop
 | 
					        except OSError as e:
 | 
				
			||||||
        )
 | 
					            __log__.info('OSError "%s" while writing data', e)
 | 
				
			||||||
 | 
					            if e.errno in CONN_RESET_ERRNOS:
 | 
				
			||||||
 | 
					                raise ConnectionResetError() from e
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def read(self, size):
 | 
					    async def read(self, size):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -113,16 +166,32 @@ class TcpClient:
 | 
				
			||||||
        :return: the read data with len(data) == size.
 | 
					        :return: the read data with len(data) == size.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if not self.is_connected:
 | 
					        if not self.is_connected:
 | 
				
			||||||
            raise ConnectionError()
 | 
					            raise ConnectionResetError('Not connected')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with BytesIO() as buffer:
 | 
					        with BytesIO() as buffer:
 | 
				
			||||||
            bytes_left = size
 | 
					            bytes_left = size
 | 
				
			||||||
            while bytes_left != 0:
 | 
					            while bytes_left != 0:
 | 
				
			||||||
                partial = await asyncio.wait_for(
 | 
					                try:
 | 
				
			||||||
                    self.sock_recv(bytes_left),
 | 
					                    partial = await self._wait_timeout_or_close(
 | 
				
			||||||
                    timeout=self.timeout,
 | 
					                        self.sock_recv(bytes_left)
 | 
				
			||||||
                    loop=self._loop
 | 
					 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					                except TimeoutError as e:
 | 
				
			||||||
 | 
					                    if bytes_left < size:
 | 
				
			||||||
 | 
					                        __log__.warning(
 | 
				
			||||||
 | 
					                            'socket timeout "%s" when %d/%d had been received',
 | 
				
			||||||
 | 
					                            e, size - bytes_left, size
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    raise
 | 
				
			||||||
 | 
					                except self.SocketClosed:
 | 
				
			||||||
 | 
					                    raise ConnectionResetError(
 | 
				
			||||||
 | 
					                        'Socket has closed while reading data'
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                except OSError as e:
 | 
				
			||||||
 | 
					                    if e.errno in CONN_RESET_ERRNOS:
 | 
				
			||||||
 | 
					                        raise ConnectionResetError() from e
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if not partial:
 | 
					                if not partial:
 | 
				
			||||||
                    raise ConnectionResetError()
 | 
					                    raise ConnectionResetError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -141,7 +210,7 @@ class TcpClient:
 | 
				
			||||||
    def _sock_recv(self, fut, registered_fd, n):
 | 
					    def _sock_recv(self, fut, registered_fd, n):
 | 
				
			||||||
        if registered_fd is not None:
 | 
					        if registered_fd is not None:
 | 
				
			||||||
            self._loop.remove_reader(registered_fd)
 | 
					            self._loop.remove_reader(registered_fd)
 | 
				
			||||||
        if fut.cancelled():
 | 
					        if fut.cancelled() or self._socket is None:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -165,7 +234,7 @@ class TcpClient:
 | 
				
			||||||
    def _sock_sendall(self, fut, registered_fd, data):
 | 
					    def _sock_sendall(self, fut, registered_fd, data):
 | 
				
			||||||
        if registered_fd:
 | 
					        if registered_fd:
 | 
				
			||||||
            self._loop.remove_writer(registered_fd)
 | 
					            self._loop.remove_writer(registered_fd)
 | 
				
			||||||
        if fut.cancelled():
 | 
					        if fut.cancelled() or self._socket is None:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user