mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-03-10 06:05:47 +03:00
Refactoring of TcpClient
This commit is contained in:
parent
653dd21259
commit
32bca4f1b8
|
@ -6,7 +6,6 @@ import logging
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from io import BytesIO, BufferedWriter
|
from io import BytesIO, BufferedWriter
|
||||||
|
|
||||||
MAX_TIMEOUT = 15 # in seconds
|
|
||||||
CONN_RESET_ERRNOS = {
|
CONN_RESET_ERRNOS = {
|
||||||
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
|
||||||
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
|
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
|
||||||
|
@ -24,6 +23,8 @@ class TcpClient:
|
||||||
self._socket = None
|
self._socket = None
|
||||||
self._loop = loop if loop else asyncio.get_event_loop()
|
self._loop = loop if loop else asyncio.get_event_loop()
|
||||||
self._logger = logging.getLogger(__name__)
|
self._logger = logging.getLogger(__name__)
|
||||||
|
self._closed = asyncio.Event(loop=self._loop)
|
||||||
|
self._closed.set()
|
||||||
|
|
||||||
if isinstance(timeout, timedelta):
|
if isinstance(timeout, timedelta):
|
||||||
self.timeout = timeout.seconds
|
self.timeout = timeout.seconds
|
||||||
|
@ -54,32 +55,27 @@ class TcpClient:
|
||||||
else:
|
else:
|
||||||
mode, address = socket.AF_INET, (ip, port)
|
mode, address = socket.AF_INET, (ip, port)
|
||||||
|
|
||||||
timeout = 1
|
try:
|
||||||
while True:
|
if not self._socket:
|
||||||
try:
|
self._recreate_socket(mode)
|
||||||
if not self._socket:
|
|
||||||
self._recreate_socket(mode)
|
|
||||||
|
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
self._loop.sock_connect(self._socket, address),
|
self._loop.sock_connect(self._socket, address),
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
loop=self._loop
|
loop=self._loop
|
||||||
)
|
)
|
||||||
break # Successful connection, stop retrying to connect
|
|
||||||
except asyncio.TimeoutError as e:
|
self._closed.clear()
|
||||||
raise TimeoutError() from e
|
except asyncio.TimeoutError as e:
|
||||||
except OSError as e:
|
raise TimeoutError() from e
|
||||||
self._logger.debug('Connect exception: %r' % e)
|
except OSError as e:
|
||||||
# ConnectionError + (errno.EBADF, errno.ENOTSOCK, errno.EINVAL)
|
if e.errno in CONN_RESET_ERRNOS:
|
||||||
if e.errno in CONN_RESET_ERRNOS:
|
self._raise_connection_reset(e)
|
||||||
self._socket = None
|
else:
|
||||||
await asyncio.sleep(timeout)
|
raise
|
||||||
timeout = min(timeout * 2, MAX_TIMEOUT)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _get_connected(self):
|
def _get_connected(self):
|
||||||
return self._socket is not None and self._socket.fileno() >= 0
|
return not self._closed.is_set()
|
||||||
|
|
||||||
connected = property(fget=_get_connected)
|
connected = property(fget=_get_connected)
|
||||||
|
|
||||||
|
@ -87,60 +83,59 @@ class TcpClient:
|
||||||
"""Closes the connection"""
|
"""Closes the connection"""
|
||||||
try:
|
try:
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
self._socket.shutdown(socket.SHUT_RDWR)
|
if self.connected:
|
||||||
|
self._socket.shutdown(socket.SHUT_RDWR)
|
||||||
self._socket.close()
|
self._socket.close()
|
||||||
except OSError:
|
except OSError:
|
||||||
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
pass # Ignore ENOTCONN, EBADF, and any other error when closing
|
||||||
finally:
|
finally:
|
||||||
self._socket = None
|
self._socket = None
|
||||||
|
self._closed.set()
|
||||||
|
|
||||||
|
async def _wait_close(self, coro):
|
||||||
|
done, _ = await asyncio.wait(
|
||||||
|
[coro, self._closed.wait()],
|
||||||
|
timeout=self.timeout,
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
loop=self._loop
|
||||||
|
)
|
||||||
|
if not self.connected:
|
||||||
|
raise ConnectionResetError('Socket has closed')
|
||||||
|
if not done:
|
||||||
|
raise TimeoutError()
|
||||||
|
return await done.pop()
|
||||||
|
|
||||||
async def write(self, data):
|
async def write(self, data):
|
||||||
"""Writes (sends) the specified bytes to the connected peer"""
|
"""Writes (sends) the specified bytes to the connected peer"""
|
||||||
if self._socket is None:
|
if not self.connected:
|
||||||
self._raise_connection_reset()
|
raise ConnectionResetError('No connection')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await self._wait_close(self.sock_sendall(data))
|
||||||
self.sock_sendall(data),
|
|
||||||
timeout=self.timeout,
|
|
||||||
loop=self._loop
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError as e:
|
|
||||||
raise TimeoutError() from e
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self._logger.debug('Write exception: %r' % e)
|
|
||||||
if e.errno in CONN_RESET_ERRNOS:
|
if e.errno in CONN_RESET_ERRNOS:
|
||||||
self._raise_connection_reset()
|
self._raise_connection_reset(e)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def read(self, size):
|
async def read(self, size):
|
||||||
"""Reads (receives) a whole block of 'size bytes
|
"""Reads (receives) a whole block of size bytes
|
||||||
from the connected peer.
|
from the connected peer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
|
||||||
bytes_left = size
|
bytes_left = size
|
||||||
while bytes_left != 0:
|
while bytes_left != 0:
|
||||||
|
if not self.connected:
|
||||||
|
raise ConnectionResetError('No connection')
|
||||||
try:
|
try:
|
||||||
if self._socket is None:
|
partial = await self._wait_close(self.sock_recv(bytes_left))
|
||||||
self._raise_connection_reset()
|
|
||||||
partial = await asyncio.wait_for(
|
|
||||||
self.sock_recv(bytes_left),
|
|
||||||
timeout=self.timeout,
|
|
||||||
loop=self._loop
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError as e:
|
|
||||||
raise TimeoutError() from e
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self._logger.debug('Read exception: %r' % e)
|
|
||||||
if e.errno in CONN_RESET_ERRNOS:
|
if e.errno in CONN_RESET_ERRNOS:
|
||||||
self._raise_connection_reset()
|
self._raise_connection_reset(e)
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if len(partial) == 0:
|
if len(partial) == 0:
|
||||||
self._raise_connection_reset()
|
self._raise_connection_reset('No data on read')
|
||||||
|
|
||||||
buffer.write(partial)
|
buffer.write(partial)
|
||||||
bytes_left -= len(partial)
|
bytes_left -= len(partial)
|
||||||
|
@ -149,9 +144,12 @@ class TcpClient:
|
||||||
buffer.flush()
|
buffer.flush()
|
||||||
return buffer.raw.getvalue()
|
return buffer.raw.getvalue()
|
||||||
|
|
||||||
def _raise_connection_reset(self):
|
def _raise_connection_reset(self, error):
|
||||||
|
description = error if isinstance(error, str) else str(error)
|
||||||
|
if isinstance(error, str):
|
||||||
|
error = Exception(error)
|
||||||
self.close() # Connection reset -> flag as socket closed
|
self.close() # Connection reset -> flag as socket closed
|
||||||
raise ConnectionResetError('The server has closed the connection.')
|
raise ConnectionResetError(description) from error
|
||||||
|
|
||||||
# due to new https://github.com/python/cpython/pull/4386
|
# due to new https://github.com/python/cpython/pull/4386
|
||||||
def sock_recv(self, n):
|
def sock_recv(self, n):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user