Fix basic requests sending and receiving

This commit is contained in:
Lonami Exo 2018-06-06 21:42:48 +02:00
parent e469258ab9
commit 9477c75fce
3 changed files with 100 additions and 37 deletions

View File

@ -73,12 +73,13 @@ class TcpClient:
if self._socket is None:
self._socket = self._create_socket(mode, self.proxy)
asyncio.wait_for(self._loop.sock_connect(self._socket, address),
self.timeout, loop=self._loop)
await asyncio.wait_for(self._loop.sock_connect(self._socket, address),
self.timeout, loop=self._loop)
@property
def is_connected(self):
"""Determines whether the client is connected or not."""
# TODO fileno() is >= 0 even before calling sock_connect!
return self._socket is not None and self._socket.fileno() >= 0
def close(self):
@ -123,7 +124,7 @@ class TcpClient:
timeout=self.timeout,
loop=self._loop
)
if not partial == 0:
if not partial:
raise ConnectionResetError()
buffer.write(partial)

View File

@ -22,7 +22,7 @@ class ConnectionTcpFull(Connection):
async def connect(self, ip, port):
try:
self.conn.connect(ip, port)
await self.conn.connect(ip, port)
except OSError as e:
if e.errno == errno.EISCONN:
return # Already connected, no need to re-set everything up
@ -35,7 +35,7 @@ class ConnectionTcpFull(Connection):
return self.conn.timeout
def is_connected(self):
return self.conn.connected
return self.conn.is_connected
async def close(self):
self.conn.close()
@ -44,10 +44,11 @@ class ConnectionTcpFull(Connection):
return ConnectionTcpFull(self._proxy, self._timeout)
async def recv(self):
packet_len_seq = self.read(8) # 4 and 4
packet_len_seq = await self.read(8) # 4 and 4
packet_len, seq = struct.unpack('<ii', packet_len_seq)
body = self.read(packet_len - 12)
checksum = struct.unpack('<I', self.read(4))[0]
body = await self.read(packet_len - 8)
checksum = struct.unpack('<I', body[-4:])[0]
body = body[:-4]
valid_checksum = crc32(packet_len_seq + body)
if checksum != valid_checksum:
@ -62,4 +63,4 @@ class ConnectionTcpFull(Connection):
data = struct.pack('<ii', length, self._send_counter) + message
crc = struct.pack('<I', crc32(data))
self._send_counter += 1
self.write(data + crc)
await self.write(data + crc)

View File

@ -3,6 +3,7 @@ import logging
from .connection import ConnectionTcpFull
from .. import helpers
from ..errors import rpc_message_to_error
from ..extensions import BinaryReader
from ..tl import TLMessage, MessageContainer, GzipPacked
from ..tl.types import (
@ -28,6 +29,10 @@ class MTProtoSender:
self._send_lock = asyncio.Lock()
self._recv_lock = asyncio.Lock()
# We need to join the loops upon disconnection
self._send_loop_handle = None
self._recv_loop_handle = None
# Sending something shouldn't block
self._send_queue = asyncio.Queue()
@ -56,9 +61,11 @@ class MTProtoSender:
# Public API
async def connect(self, ip, port):
self._user_connected = True
async with self._send_lock:
await self._connection.connect(ip, port)
self._user_connected = True
self._send_loop_handle = asyncio.ensure_future(self._send_loop())
self._recv_loop_handle = asyncio.ensure_future(self._recv_loop())
async def disconnect(self):
self._user_connected = False
@ -67,6 +74,9 @@ class MTProtoSender:
await self._connection.close()
except:
__log__.exception('Ignoring exception upon disconnection')
finally:
self._send_loop_handle.cancel()
self._recv_loop_handle.cancel()
async def send(self, request):
# TODO Should the asyncio.Future creation belong here?
@ -97,48 +107,99 @@ class MTProtoSender:
message, remote_msg_id, remote_seq = helpers.unpack_message(
self.session, body)
self._pending_ack.add(remote_msg_id)
with BinaryReader(message) as reader:
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code)
if handler:
handler(remote_msg_id, remote_seq, reader)
else:
pass # TODO Process updates
await self._process_message(remote_msg_id, remote_seq, reader)
# Response Handlers
def _handle_rpc_result(self, msg_id, seq, reader):
async def _process_message(self, msg_id, seq, reader):
self._pending_ack.add(msg_id)
code = reader.read_int(signed=False)
reader.seek(-4)
handler = self._handlers.get(code)
if handler:
await handler(msg_id, seq, reader)
else:
pass # TODO Process updates and their entities
async def _handle_rpc_result(self, msg_id, seq, reader):
# TODO Don't make this a special case
reader.read_int(signed=False) # code
message_id = reader.read_long()
inner_code = reader.read_int(signed=False)
reader.seek(-4)
message = self._pending_messages.pop(message_id)
if inner_code == 0x2144ca19: # RPC Error
reader.seek(4)
if self.session.report_errors and message:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string(),
report_method=type(message.request).CONSTRUCTOR_ID
)
else:
error = rpc_message_to_error(
reader.read_int(), reader.tgread_string()
)
# TODO Acknowledge that we received the error request_id
# TODO Set message.request exception
elif message:
# TODO Make on_response result.set_result() instead replacing it
if inner_code == GzipPacked.CONSTRUCTOR_ID:
with BinaryReader(GzipPacked.read(reader)) as compressed_reader:
message.on_response(compressed_reader)
else:
message.on_response(reader)
# TODO Process possible entities
# TODO Try reading an object
async def _handle_container(self, msg_id, seq, reader):
for inner_msg_id, _, inner_len in MessageContainer.iter_read(reader):
next_position = reader.tell_position() + inner_len
await self._process_message(inner_msg_id, seq, reader)
reader.set_position(next_position) # Ensure reading correctly
async def _handle_gzip_packed(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_container(self, msg_id, seq, reader):
async def _handle_pong(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_gzip_packed(self, msg_id, seq, reader):
async def _handle_bad_server_salt(self, msg_id, seq, reader):
bad_salt = reader.tgread_object()
self.session.salt = bad_salt.new_server_salt
self.session.save()
# "the bad_server_salt response is received with the
# correct salt, and the message is to be re-sent with it"
await self._send_queue.put(self._pending_messages[bad_salt.bad_msg_id])
async def _handle_bad_notification(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_pong(self, msg_id, seq, reader):
async def _handle_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_bad_server_salt(self, msg_id, seq, reader):
async def _handle_new_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_bad_notification(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_new_session_created(self, msg_id, seq, reader):
# TODO https://goo.gl/LMyN7A
new_session = reader.tgread_object()
self.session.salt = new_session.server_salt
def _handle_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
async def _handle_ack(self, msg_id, seq, reader):
# Ignore every ack request *unless* when logging out, when it's
# when it seems to only make sense. We also need to set a non-None
# result since Telegram doesn't send the response for these.
for msg_id in reader.tgread_object().msg_ids:
# TODO pop msg_id if of type LogOutRequest, and confirm it
pass
def _handle_new_detailed_info(self, msg_id, seq, reader):
raise NotImplementedError
return True
def _handle_new_session_created(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_ack(self, msg_id, seq, reader):
raise NotImplementedError
def _handle_future_salts(self, msg_id, seq, reader):
async def _handle_future_salts(self, msg_id, seq, reader):
raise NotImplementedError