Refactoring of TcpClient

This commit is contained in:
Andrey Egorov 2017-11-19 01:55:40 +03:00
parent 653dd21259
commit 32bca4f1b8

View File

@ -6,7 +6,6 @@ import logging
from datetime import timedelta from datetime import timedelta
from io import BytesIO, BufferedWriter from io import BytesIO, BufferedWriter
MAX_TIMEOUT = 15 # in seconds
CONN_RESET_ERRNOS = { CONN_RESET_ERRNOS = {
errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH, errno.EBADF, errno.ENOTSOCK, errno.ENETUNREACH,
errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH, errno.EINVAL, errno.ENOTCONN, errno.EHOSTUNREACH,
@ -24,6 +23,8 @@ class TcpClient:
self._socket = None self._socket = None
self._loop = loop if loop else asyncio.get_event_loop() self._loop = loop if loop else asyncio.get_event_loop()
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
self._closed = asyncio.Event(loop=self._loop)
self._closed.set()
if isinstance(timeout, timedelta): if isinstance(timeout, timedelta):
self.timeout = timeout.seconds self.timeout = timeout.seconds
@ -54,32 +55,27 @@ class TcpClient:
else: else:
mode, address = socket.AF_INET, (ip, port) mode, address = socket.AF_INET, (ip, port)
timeout = 1 try:
while True: if not self._socket:
try: self._recreate_socket(mode)
if not self._socket:
self._recreate_socket(mode)
await asyncio.wait_for( await asyncio.wait_for(
self._loop.sock_connect(self._socket, address), self._loop.sock_connect(self._socket, address),
timeout=self.timeout, timeout=self.timeout,
loop=self._loop loop=self._loop
) )
break # Successful connection, stop retrying to connect
except asyncio.TimeoutError as e: self._closed.clear()
raise TimeoutError() from e except asyncio.TimeoutError as e:
except OSError as e: raise TimeoutError() from e
self._logger.debug('Connect exception: %r' % e) except OSError as e:
# ConnectionError + (errno.EBADF, errno.ENOTSOCK, errno.EINVAL) if e.errno in CONN_RESET_ERRNOS:
if e.errno in CONN_RESET_ERRNOS: self._raise_connection_reset(e)
self._socket = None else:
await asyncio.sleep(timeout) raise
timeout = min(timeout * 2, MAX_TIMEOUT)
else:
raise
def _get_connected(self): def _get_connected(self):
return self._socket is not None and self._socket.fileno() >= 0 return not self._closed.is_set()
connected = property(fget=_get_connected) connected = property(fget=_get_connected)
@ -87,60 +83,59 @@ class TcpClient:
"""Closes the connection""" """Closes the connection"""
try: try:
if self._socket is not None: if self._socket is not None:
self._socket.shutdown(socket.SHUT_RDWR) if self.connected:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close() self._socket.close()
except OSError: except OSError:
pass # Ignore ENOTCONN, EBADF, and any other error when closing pass # Ignore ENOTCONN, EBADF, and any other error when closing
finally: finally:
self._socket = None self._socket = None
self._closed.set()
async def _wait_close(self, coro):
done, _ = await asyncio.wait(
[coro, self._closed.wait()],
timeout=self.timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self._loop
)
if not self.connected:
raise ConnectionResetError('Socket has closed')
if not done:
raise TimeoutError()
return await done.pop()
async def write(self, data): async def write(self, data):
"""Writes (sends) the specified bytes to the connected peer""" """Writes (sends) the specified bytes to the connected peer"""
if self._socket is None: if not self.connected:
self._raise_connection_reset() raise ConnectionResetError('No connection')
try: try:
await asyncio.wait_for( await self._wait_close(self.sock_sendall(data))
self.sock_sendall(data),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e: except OSError as e:
self._logger.debug('Write exception: %r' % e)
if e.errno in CONN_RESET_ERRNOS: if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset() self._raise_connection_reset(e)
else: else:
raise raise
async def read(self, size): async def read(self, size):
"""Reads (receives) a whole block of 'size bytes """Reads (receives) a whole block of size bytes
from the connected peer. from the connected peer.
""" """
with BufferedWriter(BytesIO(), buffer_size=size) as buffer: with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size bytes_left = size
while bytes_left != 0: while bytes_left != 0:
if not self.connected:
raise ConnectionResetError('No connection')
try: try:
if self._socket is None: partial = await self._wait_close(self.sock_recv(bytes_left))
self._raise_connection_reset()
partial = await asyncio.wait_for(
self.sock_recv(bytes_left),
timeout=self.timeout,
loop=self._loop
)
except asyncio.TimeoutError as e:
raise TimeoutError() from e
except OSError as e: except OSError as e:
self._logger.debug('Read exception: %r' % e)
if e.errno in CONN_RESET_ERRNOS: if e.errno in CONN_RESET_ERRNOS:
self._raise_connection_reset() self._raise_connection_reset(e)
else: else:
raise raise
if len(partial) == 0: if len(partial) == 0:
self._raise_connection_reset() self._raise_connection_reset('No data on read')
buffer.write(partial) buffer.write(partial)
bytes_left -= len(partial) bytes_left -= len(partial)
@ -149,9 +144,12 @@ class TcpClient:
buffer.flush() buffer.flush()
return buffer.raw.getvalue() return buffer.raw.getvalue()
def _raise_connection_reset(self): def _raise_connection_reset(self, error):
description = error if isinstance(error, str) else str(error)
if isinstance(error, str):
error = Exception(error)
self.close() # Connection reset -> flag as socket closed self.close() # Connection reset -> flag as socket closed
raise ConnectionResetError('The server has closed the connection.') raise ConnectionResetError(description) from error
# due to new https://github.com/python/cpython/pull/4386 # due to new https://github.com/python/cpython/pull/4386
def sock_recv(self, n): def sock_recv(self, n):