mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-03-09 21:55:48 +03:00
Refactoring of TcpClient
This commit is contained in:
parent
653dd21259
commit
32bca4f1b8
|
@ -6,7 +6,6 @@ import logging
|
|||
from datetime import timedelta
|
||||
from io import BytesIO, BufferedWriter
|
||||
|
||||
MAX_TIMEOUT = 15 # in seconds
|
||||
CONN_RESET_ERRNOS = {
|
||||
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
|
||||
|
@ -24,6 +23,8 @@ class TcpClient:
|
|||
self._socket = None
|
||||
self._loop = loop if loop else asyncio.get_event_loop()
|
||||
self._logger = logging.getLogger(__name__)
|
||||
self._closed = asyncio.Event(loop=self._loop)
|
||||
self._closed.set()
|
||||
|
||||
if isinstance(timeout, timedelta):
|
||||
self.timeout = timeout.seconds
|
||||
|
@ -54,32 +55,27 @@ class TcpClient:
|
|||
else:
|
||||
mode, address = socket.AF_INET, (ip, port)
|
||||
|
||||
timeout = 1
|
||||
while True:
|
||||
try:
|
||||
if not self._socket:
|
||||
self._recreate_socket(mode)
|
||||
try:
|
||||
if not self._socket:
|
||||
self._recreate_socket(mode)
|
||||
|
||||
await asyncio.wait_for(
|
||||
self._loop.sock_connect(self._socket, address),
|
||||
timeout=self.timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
break # Successful connection, stop retrying to connect
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError() from e
|
||||
except OSError as e:
|
||||
self._logger.debug('Connect exception: %r' % e)
|
||||
# ConnectionError + (errno.EBADF, errno.ENOTSOCK, errno.EINVAL)
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._socket = None
|
||||
await asyncio.sleep(timeout)
|
||||
timeout = min(timeout * 2, MAX_TIMEOUT)
|
||||
else:
|
||||
raise
|
||||
await asyncio.wait_for(
|
||||
self._loop.sock_connect(self._socket, address),
|
||||
timeout=self.timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
|
||||
self._closed.clear()
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError() from e
|
||||
except OSError as e:
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_connected(self):
|
||||
return self._socket is not None and self._socket.fileno() >= 0
|
||||
return not self._closed.is_set()
|
||||
|
||||
connected = property(fget=_get_connected)
|
||||
|
||||
|
@ -87,60 +83,59 @@ class TcpClient:
|
|||
"""Closes the connection"""
|
||||
try:
|
||||
if self._socket is not None:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
if self.connected:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
self._socket.close()
|
||||
except OSError:
|
||||
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
||||
finally:
|
||||
self._socket = None
|
||||
self._closed.set()
|
||||
|
||||
async def _wait_close(self, coro):
|
||||
done, _ = await asyncio.wait(
|
||||
[coro, self._closed.wait()],
|
||||
timeout=self.timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
loop=self._loop
|
||||
)
|
||||
if not self.connected:
|
||||
raise ConnectionResetError('Socket has closed')
|
||||
if not done:
|
||||
raise TimeoutError()
|
||||
return await done.pop()
|
||||
|
||||
async def write(self, data):
|
||||
"""Writes (sends) the specified bytes to the connected peer"""
|
||||
if self._socket is None:
|
||||
self._raise_connection_reset()
|
||||
|
||||
if not self.connected:
|
||||
raise ConnectionResetError('No connection')
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.sock_sendall(data),
|
||||
timeout=self.timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError() from e
|
||||
await self._wait_close(self.sock_sendall(data))
|
||||
except OSError as e:
|
||||
self._logger.debug('Write exception: %r' % e)
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset()
|
||||
self._raise_connection_reset(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def read(self, size):
|
||||
"""Reads (receives) a whole block of 'size bytes
|
||||
"""Reads (receives) a whole block of size bytes
|
||||
from the connected peer.
|
||||
"""
|
||||
|
||||
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
||||
bytes_left = size
|
||||
while bytes_left != 0:
|
||||
if not self.connected:
|
||||
raise ConnectionResetError('No connection')
|
||||
try:
|
||||
if self._socket is None:
|
||||
self._raise_connection_reset()
|
||||
partial = await asyncio.wait_for(
|
||||
self.sock_recv(bytes_left),
|
||||
timeout=self.timeout,
|
||||
loop=self._loop
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError() from e
|
||||
partial = await self._wait_close(self.sock_recv(bytes_left))
|
||||
except OSError as e:
|
||||
self._logger.debug('Read exception: %r' % e)
|
||||
if e.errno in CONN_RESET_ERRNOS:
|
||||
self._raise_connection_reset()
|
||||
self._raise_connection_reset(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
if len(partial) == 0:
|
||||
self._raise_connection_reset()
|
||||
self._raise_connection_reset('No data on read')
|
||||
|
||||
buffer.write(partial)
|
||||
bytes_left -= len(partial)
|
||||
|
@ -149,9 +144,12 @@ class TcpClient:
|
|||
buffer.flush()
|
||||
return buffer.raw.getvalue()
|
||||
|
||||
def _raise_connection_reset(self):
|
||||
def _raise_connection_reset(self, error):
|
||||
description = error if isinstance(error, str) else str(error)
|
||||
if isinstance(error, str):
|
||||
error = Exception(error)
|
||||
self.close() # Connection reset -> flag as socket closed
|
||||
raise ConnectionResetError('The server has closed the connection.')
|
||||
raise ConnectionResetError(description) from error
|
||||
|
||||
# due to new https://github.com/python/cpython/pull/4386
|
||||
def sock_recv(self, n):
|
||||
|
|
Loading…
Reference in New Issue
Block a user