Turn timeout into a property instead leaving it as a parameter

This commit is contained in:
Lonami Exo 2017-06-22 11:43:42 +02:00
parent 52a42661ee
commit e4fbd87c75
5 changed files with 97 additions and 67 deletions

View File

@ -80,7 +80,7 @@ class TcpClient:
# Set the starting time so we can # Set the starting time so we can
# calculate whether the timeout should fire # calculate whether the timeout should fire
start_time = datetime.now() if timeout else None start_time = datetime.now() if timeout is not None else None
with BufferedWriter(BytesIO(), buffer_size=size) as buffer: with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size bytes_left = size
@ -104,7 +104,7 @@ class TcpClient:
time.sleep(self.delay) time.sleep(self.delay)
# Check if the timeout finished # Check if the timeout finished
if timeout: if timeout is not None:
time_passed = datetime.now() - start_time time_passed = datetime.now() - start_time
if time_passed > timeout: if time_passed > timeout:
raise TimeoutError( raise TimeoutError(

View File

@ -17,7 +17,7 @@ class MtProtoSender:
"""MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)""" """MTProto Mobile Protocol sender (https://core.telegram.org/mtproto/description)"""
def __init__(self, transport, session): def __init__(self, transport, session):
self._transport = transport self.transport = transport
self.session = session self.session = session
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
@ -33,14 +33,14 @@ class MtProtoSender:
def connect(self): def connect(self):
"""Connects to the server""" """Connects to the server"""
self._transport.connect() self.transport.connect()
def is_connected(self): def is_connected(self):
return self._transport.is_connected() return self.transport.is_connected()
def disconnect(self): def disconnect(self):
"""Disconnects from the server""" """Disconnects from the server"""
self._transport.close() self.transport.close()
# region Send and receive # region Send and receive
@ -76,11 +76,12 @@ class MtProtoSender:
del self._need_confirmation[:] del self._need_confirmation[:]
def receive(self, request=None, timeout=timedelta(seconds=5), updates=None): def receive(self, request=None, updates=None, **kwargs):
"""Receives the specified MTProtoRequest ("fills in it" """Receives the specified MTProtoRequest ("fills in it"
the received data). This also restores the updates thread. the received data). This also restores the updates thread.
An optional timeout can be specified to cancel the operation
if no data has been read after its time delta. An optional named parameter 'timeout' can be specified if
one desires to override 'self.transport.timeout'.
If 'request' is None, a single item will be read into If 'request' is None, a single item will be read into
the 'updates' list (which cannot be None). the 'updates' list (which cannot be None).
@ -100,7 +101,7 @@ class MtProtoSender:
while (request and not request.confirm_received) or \ while (request and not request.confirm_received) or \
(not request and not updates): (not request and not updates):
self._logger.info('Trying to .receive() the request result...') self._logger.info('Trying to .receive() the request result...')
seq, body = self._transport.receive(timeout) seq, body = self.transport.receive(**kwargs)
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
with BinaryReader(message) as reader: with BinaryReader(message) as reader:
@ -116,18 +117,16 @@ class MtProtoSender:
self._logger.info('Request result received') self._logger.info('Request result received')
self._logger.debug('receive() released the lock') self._logger.debug('receive() released the lock')
def receive_updates(self, timeout=timedelta(seconds=5)): def receive_updates(self, **kwargs):
"""Receives one or more update objects """Wrapper for .receive(request=None, updates=[])"""
and returns them as a list
"""
updates = [] updates = []
self.receive(timeout=timeout, updates=updates) self.receive(updates=updates, **kwargs)
return updates return updates
def cancel_receive(self): def cancel_receive(self):
"""Cancels any pending receive operation """Cancels any pending receive operation
by raising a ReadCancelledError""" by raising a ReadCancelledError"""
self._transport.cancel_receive() self.transport.cancel_receive()
# endregion # endregion
@ -160,7 +159,7 @@ class MtProtoSender:
self.session.auth_key.key_id, signed=False) self.session.auth_key.key_id, signed=False)
cipher_writer.write(msg_key) cipher_writer.write(msg_key)
cipher_writer.write(cipher_text) cipher_writer.write(cipher_text)
self._transport.send(cipher_writer.get_bytes()) self.transport.send(cipher_writer.get_bytes())
def _decode_msg(self, body): def _decode_msg(self, body):
"""Decodes an received encrypted message body bytes""" """Decodes an received encrypted message body bytes"""

View File

@ -7,10 +7,12 @@ from ..extensions import BinaryWriter
class TcpTransport: class TcpTransport:
def __init__(self, ip_address, port, proxy=None): def __init__(self, ip_address, port,
proxy=None, timeout=timedelta(seconds=5)):
self.ip = ip_address self.ip = ip_address
self.port = port self.port = port
self.tcp_client = TcpClient(proxy) self.tcp_client = TcpClient(proxy)
self.timeout = timeout
self.send_counter = 0 self.send_counter = 0
def connect(self): def connect(self):
@ -22,7 +24,8 @@ class TcpTransport:
return self.tcp_client.connected return self.tcp_client.connected
# Original reference: https://core.telegram.org/mtproto#tcp-transport # Original reference: https://core.telegram.org/mtproto#tcp-transport
# The packets are encoded as: total length, sequence number, packet and checksum (CRC32) # The packets are encoded as:
# total length, sequence number, packet and checksum (CRC32)
def send(self, packet): def send(self, packet):
"""Sends the given packet (bytes array) to the connected peer""" """Sends the given packet (bytes array) to the connected peer"""
if not self.tcp_client.connected: if not self.tcp_client.connected:
@ -39,10 +42,14 @@ class TcpTransport:
self.send_counter += 1 self.send_counter += 1
self.tcp_client.write(writer.get_bytes()) self.tcp_client.write(writer.get_bytes())
def receive(self, timeout=timedelta(seconds=5)): def receive(self, **kwargs):
"""Receives a TCP message (tuple(sequence number, body)) from the connected peer. """Receives a TCP message (tuple(sequence number, body)) from the
There is a default timeout of 5 seconds before the operation is cancelled. connected peer.
Timeout can be set to None for no timeout"""
If a named 'timeout' parameter is present, it will override
'self.timeout', and this can be a 'timedelta' or 'None'.
"""
timeout = kwargs.get('timeout', self.timeout)
# First read everything we need # First read everything we need
packet_length_bytes = self.tcp_client.read(4, timeout) packet_length_bytes = self.tcp_client.read(4, timeout)

View File

@ -51,7 +51,8 @@ class TelegramBareClient:
# region Initialization # region Initialization
def __init__(self, session, api_id, api_hash, proxy=None): def __init__(self, session, api_id, api_hash,
proxy=None, timeout=timedelta(seconds=5)):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
Session must always be a Session instance, and an optional proxy Session must always be a Session instance, and an optional proxy
can also be specified to be used on the connection. can also be specified to be used on the connection.
@ -60,35 +61,36 @@ class TelegramBareClient:
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
self.proxy = proxy self.proxy = proxy
self._timeout = timeout
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# These will be set later # These will be set later
self.dc_options = None self.dc_options = None
self.sender = None self._sender = None
# endregion # endregion
# region Connecting # region Connecting
def connect(self, timeout=timedelta(seconds=5), exported_auth=None): def connect(self, exported_auth=None):
"""Connects to the Telegram servers, executing authentication if """Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which not the same as authenticating the desired user itself, which
may require a call (or several) to 'sign_in' for the first time. may require a call (or several) to 'sign_in' for the first time.
The specified timeout will be used on internal .invoke()'s.
If 'exported_auth' is not None, it will be used instead to If 'exported_auth' is not None, it will be used instead to
determine the authorization key for the current session. determine the authorization key for the current session.
""" """
if self.sender and self.sender.is_connected(): if self._sender and self._sender.is_connected():
self._logger.warning( self._logger.warning(
'Attempted to connect when the client was already connected.' 'Attempted to connect when the client was already connected.'
) )
return return
transport = TcpTransport(self.session.server_address, transport = TcpTransport(self.session.server_address,
self.session.port, proxy=self.proxy) self.session.port,
proxy=self.proxy,
timeout=self._timeout)
try: try:
if not self.session.auth_key: if not self.session.auth_key:
@ -97,8 +99,8 @@ class TelegramBareClient:
self.session.save() self.session.save()
self.sender = MtProtoSender(transport, self.session) self._sender = MtProtoSender(transport, self.session)
self.sender.connect() self._sender.connect()
# Now it's time to send an InitConnectionRequest # Now it's time to send an InitConnectionRequest
# This must always be invoked with the layer we'll be using # This must always be invoked with the layer we'll be using
@ -117,14 +119,13 @@ class TelegramBareClient:
query=query) query=query)
result = self.invoke( result = self.invoke(
InvokeWithLayerRequest(layer=layer, query=request), InvokeWithLayerRequest(layer=layer, query=request)
timeout=timeout
) )
if exported_auth is not None: if exported_auth is not None:
# TODO Don't actually need this for exported authorizations, # TODO Don't actually need this for exported authorizations,
# they're only valid on such data center. # they're only valid on such data center.
result = self.invoke(GetConfigRequest(), timeout=timeout) result = self.invoke(GetConfigRequest())
# We're only interested in the DC options, # We're only interested in the DC options,
# although many other options are available! # although many other options are available!
@ -140,9 +141,9 @@ class TelegramBareClient:
def disconnect(self): def disconnect(self):
"""Disconnects from the Telegram server""" """Disconnects from the Telegram server"""
if self.sender: if self._sender:
self.sender.disconnect() self._sender.disconnect()
self.sender = None self._sender = None
def reconnect(self, new_dc=None): def reconnect(self, new_dc=None):
"""Disconnects and connects again (effectively reconnecting). """Disconnects and connects again (effectively reconnecting).
@ -163,6 +164,30 @@ class TelegramBareClient:
# endregion # endregion
# region Properties
def set_timeout(self, timeout):
if timeout is None:
self._timeout = None
elif isinstance(timeout, int) or isinstance(timeout, float):
self._timeout = timedelta(seconds=timeout)
elif isinstance(timeout, timedelta):
self._timeout = timeout
else:
raise ValueError(
'{} is not a valid type for a timeout'.format(type(timeout))
)
if self._sender:
self._sender.transport.timeout = self._timeout
def get_timeout(self):
return self._timeout
timeout = property(get_timeout, set_timeout)
# endregion
# region Working with different Data Centers # region Working with different Data Centers
def _get_dc(self, dc_id): def _get_dc(self, dc_id):
@ -178,31 +203,28 @@ class TelegramBareClient:
# region Invoking Telegram requests # region Invoking Telegram requests
def invoke(self, request, timeout=timedelta(seconds=5), updates=None): def invoke(self, request, updates=None):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result. """Invokes (sends) a MTProtoRequest and returns (receives) its result.
An optional timeout can be specified to cancel the operation if no
result is received within such time, or None to disable any timeout.
If 'updates' is not None, all read update object will be put If 'updates' is not None, all read update object will be put
in such list. Otherwise, update objects will be ignored. in such list. Otherwise, update objects will be ignored.
""" """
if not isinstance(request, MTProtoRequest): if not isinstance(request, MTProtoRequest):
raise ValueError('You can only invoke MtProtoRequests') raise ValueError('You can only invoke MtProtoRequests')
if not self.sender: if not self._sender:
raise ValueError('You must be connected to invoke requests!') raise ValueError('You must be connected to invoke requests!')
try: try:
self.sender.send(request) self._sender.send(request)
self.sender.receive(request, timeout, updates=updates) self._sender.receive(request, updates=updates)
return request.result return request.result
except ConnectionResetError: except ConnectionResetError:
self._logger.info('Server disconnected us. Reconnecting and ' self._logger.info('Server disconnected us. Reconnecting and '
'resending request...') 'resending request...')
self.reconnect() self.reconnect()
return self.invoke(request, timeout=timeout) return self.invoke(request)
except FloodWaitError: except FloodWaitError:
self.disconnect() self.disconnect()

View File

@ -64,7 +64,8 @@ class TelegramClient(TelegramBareClient):
def __init__(self, session, api_id, api_hash, proxy=None, def __init__(self, session, api_id, api_hash, proxy=None,
device_model=None, system_version=None, device_model=None, system_version=None,
app_version=None, lang_code=None): app_version=None, lang_code=None,
timeout=timedelta(seconds=5)):
"""Initializes the Telegram client with the specified API ID and Hash. """Initializes the Telegram client with the specified API ID and Hash.
Session can either be a `str` object (filename for the .session) Session can either be a `str` object (filename for the .session)
@ -78,7 +79,6 @@ class TelegramClient(TelegramBareClient):
app_version = TelegramClient.__version__ app_version = TelegramClient.__version__
lang_code = 'en' lang_code = 'en'
""" """
if not api_id or not api_hash: if not api_id or not api_hash:
raise PermissionError( raise PermissionError(
"Your API ID or Hash cannot be empty or None. " "Your API ID or Hash cannot be empty or None. "
@ -92,7 +92,7 @@ class TelegramClient(TelegramBareClient):
raise ValueError( raise ValueError(
'The given session must be a str or a Session instance.') 'The given session must be a str or a Session instance.')
super().__init__(session, api_id, api_hash, proxy) super().__init__(session, api_id, api_hash, proxy, timeout=timeout)
# Safety across multiple threads (for the updates thread) # Safety across multiple threads (for the updates thread)
self._lock = RLock() self._lock = RLock()
@ -129,7 +129,7 @@ class TelegramClient(TelegramBareClient):
# region Connecting # region Connecting
def connect(self, timeout=timedelta(seconds=5), *args): def connect(self, *args):
"""Connects to the Telegram servers, executing authentication if """Connects to the Telegram servers, executing authentication if
required. Note that authenticating to the Telegram servers is required. Note that authenticating to the Telegram servers is
not the same as authenticating the desired user itself, which not the same as authenticating the desired user itself, which
@ -195,7 +195,10 @@ class TelegramClient(TelegramBareClient):
session = JsonSession(self.session) session = JsonSession(self.session)
session.server_address = dc.ip_address session.server_address = dc.ip_address
session.port = dc.port session.port = dc.port
client = TelegramBareClient(session, self.api_id, self.api_hash) client = TelegramBareClient(
session, self.api_id, self.api_hash,
timeout=self._timeout
)
client.connect(exported_auth=export_auth) client.connect(exported_auth=export_auth)
if not bypass_cache: if not bypass_cache:
@ -233,7 +236,7 @@ class TelegramClient(TelegramBareClient):
# region Telegram requests functions # region Telegram requests functions
def invoke(self, request, timeout=timedelta(seconds=5), *args): def invoke(self, request, *args):
"""Invokes (sends) a MTProtoRequest and returns (receives) its result. """Invokes (sends) a MTProtoRequest and returns (receives) its result.
An optional timeout can be specified to cancel the operation if no An optional timeout can be specified to cancel the operation if no
@ -244,18 +247,19 @@ class TelegramClient(TelegramBareClient):
if not issubclass(type(request), MTProtoRequest): if not issubclass(type(request), MTProtoRequest):
raise ValueError('You can only invoke MtProtoRequests') raise ValueError('You can only invoke MtProtoRequests')
if not self.sender: if not self._sender:
raise ValueError('You must be connected to invoke requests!') raise ValueError('You must be connected to invoke requests!')
if self._updates_thread_receiving.is_set(): if self._updates_thread_receiving.is_set():
self.sender.cancel_receive() self._sender.cancel_receive()
try: try:
self._lock.acquire() self._lock.acquire()
updates = [] if self._update_handlers else None updates = [] if self._update_handlers else None
result = super(TelegramClient, self).invoke( result = super(TelegramClient, self).invoke(
request, timeout=timeout, updates=updates) request, updates=updates
)
if updates: if updates:
for update in updates: for update in updates:
@ -271,13 +275,12 @@ class TelegramClient(TelegramBareClient):
.format(e.new_dc)) .format(e.new_dc))
self.reconnect(new_dc=e.new_dc) self.reconnect(new_dc=e.new_dc)
return self.invoke(request, timeout=timeout) return self.invoke(request)
finally: finally:
self._lock.release() self._lock.release()
def invoke_on_dc(self, request, dc_id, def invoke_on_dc(self, request, dc_id, reconnect=False):
timeout=timedelta(seconds=5), reconnect=False):
"""Invokes the given request on a different DC """Invokes the given request on a different DC
by making use of the exported MtProtoSenders. by making use of the exported MtProtoSenders.
@ -294,8 +297,7 @@ class TelegramClient(TelegramBareClient):
if reconnect: if reconnect:
raise raise
else: else:
return self.invoke_on_dc(request, dc_id, return self.invoke_on_dc(request, dc_id, reconnect=True)
timeout=timeout, reconnect=True)
# region Authorization requests # region Authorization requests
@ -374,7 +376,7 @@ class TelegramClient(TelegramBareClient):
Returns True if everything went okay.""" Returns True if everything went okay."""
# Special flag when logging out (so the ack request confirms it) # Special flag when logging out (so the ack request confirms it)
self.sender.logging_out = True self._sender.logging_out = True
try: try:
self.invoke(LogOutRequest()) self.invoke(LogOutRequest())
self.disconnect() self.disconnect()
@ -385,7 +387,7 @@ class TelegramClient(TelegramBareClient):
return True return True
except (RPCError, ConnectionError): except (RPCError, ConnectionError):
# Something happened when logging out, restore the state back # Something happened when logging out, restore the state back
self.sender.logging_out = False self._sender.logging_out = False
return False return False
def get_me(self): def get_me(self):
@ -756,7 +758,7 @@ class TelegramClient(TelegramBareClient):
def add_update_handler(self, handler): def add_update_handler(self, handler):
"""Adds an update handler (a function which takes a TLObject, """Adds an update handler (a function which takes a TLObject,
an update, as its parameter) and listens for updates""" an update, as its parameter) and listens for updates"""
if not self.sender: if not self._sender:
raise RuntimeError("You can't add update handlers until you've " raise RuntimeError("You can't add update handlers until you've "
"successfully connected to the server.") "successfully connected to the server.")
@ -791,7 +793,7 @@ class TelegramClient(TelegramBareClient):
else: else:
self._updates_thread_running.clear() self._updates_thread_running.clear()
if self._updates_thread_receiving.is_set(): if self._updates_thread_receiving.is_set():
self.sender.cancel_receive() self._sender.cancel_receive()
def _updates_thread_method(self): def _updates_thread_method(self):
"""This method will run until specified and listen for incoming updates""" """This method will run until specified and listen for incoming updates"""
@ -817,7 +819,7 @@ class TelegramClient(TelegramBareClient):
self._next_ping_at = time() + self.ping_interval self._next_ping_at = time() + self.ping_interval
self.invoke(PingRequest(utils.generate_random_long())) self.invoke(PingRequest(utils.generate_random_long()))
updates = self.sender.receive_updates(timeout=timeout) updates = self._sender.receive_updates(timeout=timeout)
self._updates_thread_receiving.clear() self._updates_thread_receiving.clear()
self._logger.info( self._logger.info(
@ -848,9 +850,9 @@ class TelegramClient(TelegramBareClient):
except OSError: except OSError:
self._logger.warning('OSError on updates thread, %s logging out', self._logger.warning('OSError on updates thread, %s logging out',
'was' if self.sender.logging_out else 'was not') 'was' if self._sender.logging_out else 'was not')
if self.sender.logging_out: if self._sender.logging_out:
# This error is okay when logging out, means we got disconnected # This error is okay when logging out, means we got disconnected
# TODO Not sure why this happens because we call disconnect()... # TODO Not sure why this happens because we call disconnect()...
self._set_updates_thread(running=False) self._set_updates_thread(running=False)