Create a Connection only once and avoid no-op if was "connected"

This commit is contained in:
Lonami Exo 2017-09-21 13:43:33 +02:00
parent 4777b8dad4
commit 2b2da843a1
6 changed files with 38 additions and 44 deletions

View File

@ -13,9 +13,9 @@ class TcpClient:
self._closing_lock = Lock() self._closing_lock = Lock()
if isinstance(timeout, timedelta): if isinstance(timeout, timedelta):
self._timeout = timeout.seconds self.timeout = timeout.seconds
elif isinstance(timeout, int) or isinstance(timeout, float): elif isinstance(timeout, int) or isinstance(timeout, float):
self._timeout = float(timeout) self.timeout = float(timeout)
else: else:
raise ValueError('Invalid timeout type', type(timeout)) raise ValueError('Invalid timeout type', type(timeout))
@ -30,7 +30,7 @@ class TcpClient:
else: # tuple, list, etc. else: # tuple, list, etc.
self._socket.set_proxy(*self._proxy) self._socket.set_proxy(*self._proxy)
self._socket.settimeout(self._timeout) self._socket.settimeout(self.timeout)
def connect(self, ip, port): def connect(self, ip, port):
"""Connects to the specified IP and port number. """Connects to the specified IP and port number.
@ -81,6 +81,8 @@ class TcpClient:
def write(self, data): def write(self, data):
"""Writes (sends) the specified bytes to the connected peer""" """Writes (sends) the specified bytes to the connected peer"""
if self._socket is None:
raise ConnectionResetError()
# TODO Timeout may be an issue when sending the data, Changed in v3.5: # TODO Timeout may be an issue when sending the data, Changed in v3.5:
# The socket timeout is now the maximum total duration to send all data. # The socket timeout is now the maximum total duration to send all data.
@ -105,6 +107,9 @@ class TcpClient:
and it's waiting for more, the timeout will NOT cancel the and it's waiting for more, the timeout will NOT cancel the
operation. Set to None for no timeout operation. Set to None for no timeout
""" """
if self._socket is None:
raise ConnectionResetError()
# TODO Remove the timeout from this method, always use previous one # TODO Remove the timeout from this method, always use previous one
with BufferedWriter(BytesIO(), buffer_size=size) as buffer: with BufferedWriter(BytesIO(), buffer_size=size) as buffer:
bytes_left = size bytes_left = size

View File

@ -2,6 +2,8 @@ import os
import time import time
from hashlib import sha1 from hashlib import sha1
import errno
from .. import helpers as utils from .. import helpers as utils
from ..crypto import AES, AuthKey, Factorization from ..crypto import AES, AuthKey, Factorization
from ..crypto import rsa from ..crypto import rsa
@ -30,7 +32,6 @@ def _do_authentication(connection):
time offset. time offset.
""" """
sender = MtProtoPlainSender(connection) sender = MtProtoPlainSender(connection)
sender.connect()
# Step 1 sending: PQ Request # Step 1 sending: PQ Request
nonce = os.urandom(16) nonce = os.urandom(16)

View File

@ -93,6 +93,9 @@ class Connection:
elif self._mode == ConnectionMode.TCP_OBFUSCATED: elif self._mode == ConnectionMode.TCP_OBFUSCATED:
self._setup_obfuscation() self._setup_obfuscation()
def get_timeout(self):
return self.conn.timeout
def _setup_obfuscation(self): def _setup_obfuscation(self):
# Obfuscated messages secrets cannot start with any of these # Obfuscated messages secrets cannot start with any of these
keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4) keywords = (b'PVrG', b'GET ', b'POST', b'\xee' * 4)

View File

@ -17,12 +17,12 @@ class MtProtoSender:
(https://core.telegram.org/mtproto/description) (https://core.telegram.org/mtproto/description)
""" """
def __init__(self, connection, session): def __init__(self, session, connection):
"""Creates a new MtProtoSender configured to send messages through """Creates a new MtProtoSender configured to send messages through
'connection' and using the parameters from 'session'. 'connection' and using the parameters from 'session'.
""" """
self.connection = connection
self.session = session self.session = session
self.connection = connection
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
self._need_confirmation = [] # Message IDs that need confirmation self._need_confirmation = [] # Message IDs that need confirmation
@ -47,6 +47,9 @@ class MtProtoSender:
def disconnect(self): def disconnect(self):
"""Disconnects from the server""" """Disconnects from the server"""
self.connection.close() self.connection.close()
self._need_confirmation.clear()
self._clear_all_pending()
self.logging_out = False
# region Send and receive # region Send and receive
@ -97,9 +100,7 @@ class MtProtoSender:
# "This packet should be skipped"; since this may have # "This packet should be skipped"; since this may have
# been a result for a request, invalidate every request # been a result for a request, invalidate every request
# and just re-invoke them to avoid problems # and just re-invoke them to avoid problems
for r in self._pending_receive: self._clear_all_pending()
r.confirm_received.set()
self._pending_receive.clear()
return return
message, remote_msg_id, remote_seq = self._decode_msg(body) message, remote_msg_id, remote_seq = self._decode_msg(body)
@ -245,6 +246,11 @@ class MtProtoSender:
if self._pending_receive[i].request_msg_id == request_msg_id: if self._pending_receive[i].request_msg_id == request_msg_id:
return self._pending_receive.pop(i) return self._pending_receive.pop(i)
def _clear_all_pending(self):
for r in self._pending_receive:
r.confirm_received.set()
self._pending_receive.clear()
def _handle_pong(self, msg_id, sequence, reader): def _handle_pong(self, msg_id, sequence, reader):
self._logger.debug('Handling pong') self._logger.debug('Handling pong')
reader.read_int(signed=False) # code reader.read_int(signed=False) # code

View File

@ -72,11 +72,13 @@ class TelegramBareClient:
self.api_id = int(api_id) self.api_id = int(api_id)
self.api_hash = api_hash self.api_hash = api_hash
if self.api_id < 20: # official apps must use obfuscated if self.api_id < 20: # official apps must use obfuscated
self._connection_mode = ConnectionMode.TCP_OBFUSCATED connection_mode = ConnectionMode.TCP_OBFUSCATED
else:
self._connection_mode = connection_mode self._sender = MtProtoSender(self.session, Connection(
self.proxy = proxy self.session.server_address, self.session.port,
self._timeout = timeout mode=connection_mode, proxy=proxy, timeout=timeout
))
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# Cache "exported" senders 'dc_id: TelegramBareClient' and # Cache "exported" senders 'dc_id: TelegramBareClient' and
@ -88,9 +90,6 @@ class TelegramBareClient:
# One may change self.updates.enabled at any later point. # One may change self.updates.enabled at any later point.
self.updates = UpdateState(process_updates) self.updates = UpdateState(process_updates)
# These will be set later
self._sender = None
# endregion # endregion
# region Connecting # region Connecting
@ -104,21 +103,14 @@ class TelegramBareClient:
If 'exported_auth' is not None, it will be used instead to If 'exported_auth' is not None, it will be used instead to
determine the authorization key for the current session. determine the authorization key for the current session.
""" """
if self.is_connected():
return True
connection = Connection(
self.session.server_address, self.session.port,
mode=self._connection_mode, proxy=self.proxy, timeout=self._timeout
)
try: try:
self._sender.connect()
if not self.session.auth_key: if not self.session.auth_key:
# New key, we need to tell the server we're going to use # New key, we need to tell the server we're going to use
# the latest layer # the latest layer
try: try:
self.session.auth_key, self.session.time_offset = \ self.session.auth_key, self.session.time_offset = \
authenticator.do_authentication(connection) authenticator.do_authentication(self._sender.connection)
except BrokenAuthKeyError: except BrokenAuthKeyError:
return False return False
@ -128,8 +120,6 @@ class TelegramBareClient:
else: else:
init_connection = self.session.layer != LAYER init_connection = self.session.layer != LAYER
self._sender = MtProtoSender(connection, self.session)
self._sender.connect()
if init_connection: if init_connection:
if exported_auth is not None: if exported_auth is not None:
@ -166,7 +156,7 @@ class TelegramBareClient:
return False return False
def is_connected(self): def is_connected(self):
return self._sender is not None and self._sender.is_connected() return self._sender.is_connected()
def _init_connection(self, query=None): def _init_connection(self, query=None):
result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest( result = self(InvokeWithLayerRequest(LAYER, InitConnectionRequest(
@ -185,9 +175,7 @@ class TelegramBareClient:
def disconnect(self): def disconnect(self):
"""Disconnects from the Telegram server""" """Disconnects from the Telegram server"""
if self._sender: self._sender.disconnect()
self._sender.disconnect()
self._sender = None
def reconnect(self, new_dc=None): def reconnect(self, new_dc=None):
"""Disconnects and connects again (effectively reconnecting). """Disconnects and connects again (effectively reconnecting).
@ -274,7 +262,7 @@ class TelegramBareClient:
session.port = dc.port session.port = dc.port
client = TelegramBareClient( client = TelegramBareClient(
session, self.api_id, self.api_hash, session, self.api_id, self.api_hash,
timeout=self._timeout timeout=self._connection.get_timeout()
) )
client.connect(exported_auth=export_auth) client.connect(exported_auth=export_auth)
@ -300,9 +288,6 @@ class TelegramBareClient:
if not isinstance(request, TLObject) and not request.content_related: if not isinstance(request, TLObject) and not request.content_related:
raise ValueError('You can only invoke requests, not types!') raise ValueError('You can only invoke requests, not types!')
if not self._sender:
raise ValueError('You must be connected to invoke requests!')
if retries <= 0: if retries <= 0:
raise ValueError('Number of retries reached 0.') raise ValueError('Number of retries reached 0.')

View File

@ -148,9 +148,6 @@ class TelegramClient(TelegramBareClient):
exported_auth is meant for internal purposes and can be ignored. exported_auth is meant for internal purposes and can be ignored.
""" """
if self._sender and self._sender.is_connected():
return
if socks and self._recv_thread: if socks and self._recv_thread:
# Treat proxy errors specially since they're not related to # Treat proxy errors specially since they're not related to
# Telegram itself, but rather to the proxy. If any happens on # Telegram itself, but rather to the proxy. If any happens on
@ -173,7 +170,7 @@ class TelegramClient(TelegramBareClient):
# read constantly or not for updates needs to be known before hand, # read constantly or not for updates needs to be known before hand,
# and further updates won't be able to be added unless allowing to # and further updates won't be able to be added unless allowing to
# switch the mode on the fly. # switch the mode on the fly.
if ok: if ok and self._recv_thread is None:
self._recv_thread = Thread( self._recv_thread = Thread(
name='ReadThread', daemon=True, name='ReadThread', daemon=True,
target=self._recv_thread_impl target=self._recv_thread_impl
@ -187,9 +184,6 @@ class TelegramClient(TelegramBareClient):
def disconnect(self): def disconnect(self):
"""Disconnects from the Telegram server """Disconnects from the Telegram server
and stops all the spawned threads""" and stops all the spawned threads"""
if not self._sender or not self._sender.is_connected():
return
# The existing thread will close eventually, since it's # The existing thread will close eventually, since it's
# only running while the MtProtoSender.is_connected() # only running while the MtProtoSender.is_connected()
self._recv_thread = None self._recv_thread = None
@ -1035,7 +1029,7 @@ class TelegramClient(TelegramBareClient):
# #
# This way, sending and receiving will be completely independent. # This way, sending and receiving will be completely independent.
def _recv_thread_impl(self): def _recv_thread_impl(self):
while self._sender and self._sender.is_connected(): while self._sender.is_connected():
try: try:
if datetime.now() > self._last_ping + self._ping_delay: if datetime.now() > self._last_ping + self._ping_delay:
self._sender.send(PingRequest( self._sender.send(PingRequest(