Refactoring of TcpClient

This commit is contained in:
Andrey Egorov 2017-11-19 01:55:40 +03:00
parent 653dd21259
commit 32bca4f1b8

View File

@ -6,7 +6,6 @@ import logging
from datetime import timedelta
from io import BytesIO, BufferedWriter
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
@ -24,6 +23,8 @@ class TcpClient:
self._socket = None
self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__)
self._closed = asyncio.Event(loop=self._loop)
self._closed.set()
if isinstance(timeout, timedelta):
self.timeout = timeout.seconds
@ -54,32 +55,27 @@ class TcpClient:
else:
mode, address = socket.AF_INET, (ip, port)
timeout = 1
while True:
try:
if not self._socket:
self._recreate_socket(mode)
try:
if not self._socket:
self._recreate_socket(mode)
await asyncio.wait_for(
self._loop.sock_connect(self._socket, address),
timeout=self.timeout,
loop=self._loop
)
break # Successful connection, stop retrying to connect
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e:
self._logger.debug('Connect exception: %r' % e)
# ConnectionError + (errno.EBADF, errno.ENOTSOCK, errno.EINVAL)
if e.errno in CONN_RESET_ERRNOS:
self._socket = None
await asyncio.sleep(timeout)
timeout = min(timeout * 2, MAX_TIMEOUT)
else:
raise
await asyncio.wait_for(
self._loop.sock_connect(self._socket, address),
timeout=self.timeout,
loop=self._loop
)
self._closed.clear()
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e:
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset(e)
else:
raise
def _get_connected(self):
return self._socket is not None and self._socket.fileno() >= 0
return not self._closed.is_set()
connected = property(fget=_get_connected)
@ -87,60 +83,59 @@ class TcpClient:
"""Closes the connection"""
try:
if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR)
if self.connected:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally:
self._socket = None
self._closed.set()
async def _wait_close(self, coro):
done, _ = await asyncio.wait(
[coro, self._closed.wait()],
timeout=self.timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
if not self.connected:
raise ConnectionResetError('Socket has closed')
if not done:
raise TimeoutError()
return await done.pop()
async def write(self, data):
"""Writes (sends) the specified bytes to the connected peer"""
if self._socket is None:
self._raise_connection_reset()
if not self.connected:
raise ConnectionResetError('No connection')
try:
await asyncio.wait_for(
self.sock_sendall(data),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
await self._wait_close(self.sock_sendall(data))
except OSError as e:
self._logger.debug('Write exception: %r' % e)
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset()
self._raise_connection_reset(e)
else:
raise
async def read(self, size):
"""Reads (receives) a whole block of 'size bytes
"""Reads (receives) a whole block of size bytes
from the connected peer.
"""
with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size
while bytes_left != 0:
if not self.connected:
raise ConnectionResetError('No connection')
try:
if self._socket is None:
self._raise_connection_reset()
partial = await asyncio.wait_for(
self.sock_recv(bytes_left),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
partial = await self._wait_close(self.sock_recv(bytes_left))
except OSError as e:
self._logger.debug('Read exception: %r' % e)
if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset()
self._raise_connection_reset(e)
else:
raise
if len(partial) == 0:
self._raise_connection_reset()
self._raise_connection_reset('No data on read')
buffer.write(partial)
bytes_left -= len(partial)
@ -149,9 +144,12 @@ class TcpClient:
buffer.flush()
return buffer.raw.getvalue()
def _raise_connection_reset(self):
def _raise_connection_reset(self, error):
description = error if isinstance(error, str) else str(error)
if isinstance(error, str):
error = Exception(error)
self.close() # Connection reset -> flag as socket closed
raise ConnectionResetError('The server has closed the connection.')
raise ConnectionResetError(description) from error
# due to new https://github.com/python/cpython/pull/4386
def sock_recv(self, n):