Use asyncio.BufferedProtocol in sender

This commit is contained in:
Lonami Exo 2024-10-18 19:07:45 +02:00
parent c588c74c08
commit 0e5ea59ecf
3 changed files with 88 additions and 60 deletions

View File

@ -1,5 +1,4 @@
from .sender import (
MAXIMUM_DATA,
NO_PING_DISCONNECT,
PING_DELAY,
AsyncReader,
@ -10,7 +9,6 @@ from .sender import (
)
__all__ = [
"MAXIMUM_DATA",
"NO_PING_DISCONNECT",
"PING_DELAY",
"AsyncReader",

View File

@ -0,0 +1,59 @@
import asyncio
from ..mtproto import (
MissingBytesError,
Transport,
)
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
class BufferedTransportProtocol(asyncio.BufferedProtocol):
__slots__ = (
"_transport",
"_buffer",
"_buffer_head",
"_packets",
"_output",
"_closed",
)
def __init__(self, transport: Transport):
self._transport = transport
self._buffer = bytearray(MAXIMUM_DATA)
self._buffer_head = 0
self._packets: asyncio.Queue[bytes] = asyncio.Queue()
self._output = bytearray()
self._closed = asyncio.Event()
# Method overrides
def get_buffer(self, sizehint):
return self._buffer
def buffer_updated(self, nbytes):
self._buffer_head += nbytes
while self._buffer_head:
self._output.clear()
try:
n = self._transport.unpack(
memoryview(self._buffer)[: self._buffer_head], self._output
)
except MissingBytesError as e:
print(e)
return
else:
del self._buffer[:n]
self._buffer += bytes(n)
self._buffer_head -= n
self._packets.put_nowait(bytes(self._output))
def connection_lost(self, exc):
self._closed.set()
# Custom methods
def wait_closed(self):
return self._closed.wait()
def wait_packet(self):
return self._packets.get()

View File

@ -10,11 +10,11 @@ from typing import Generic, Optional, Protocol, Type, TypeVar
from typing_extensions import Self
from .protocol import BufferedTransportProtocol
from ..crypto import AuthKey
from ..mtproto import (
BadMessageError,
Encrypted,
MissingBytesError,
MsgId,
Mtp,
Plain,
@ -31,7 +31,6 @@ from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
PING_DELAY = 60
@ -164,16 +163,13 @@ class Sender:
addr: str
lock: Lock
_logger: logging.Logger
_reader: AsyncReader
_writer: AsyncWriter
_connection: asyncio.Transport
_transport: Transport
_protocol: BufferedTransportProtocol
_mtp: Mtp
_mtp_buffer: bytearray
_requests: list[Request[object]]
_request_event: Event
_next_ping: float
_read_buffer: bytearray
_write_drain_pending: bool
_step_counter: int
@classmethod
@ -188,29 +184,29 @@ class Sender:
base_logger: logging.Logger,
) -> Self:
ip, port = addr.split(":")
reader, writer = await connector(ip, int(port))
# TODO BRING BACK SUPPORT FOR connector
connection, protocol = await asyncio.get_running_loop().create_connection(
lambda: BufferedTransportProtocol(transport), ip, int(port)
)
return cls(
dc_id=dc_id,
addr=addr,
lock=Lock(),
_logger=base_logger.getChild("mtsender"),
_reader=reader,
_writer=writer,
_connection=connection,
_transport=transport,
_protocol=protocol,
_mtp=mtp,
_mtp_buffer=bytearray(),
_requests=[],
_request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(),
_write_drain_pending=False,
_step_counter=0,
)
async def disconnect(self) -> None:
self._writer.close()
await self._writer.wait_closed()
self._connection.close()
await self._protocol.wait_closed()
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request))
@ -251,14 +247,20 @@ class Sender:
async def _step(self) -> list[Updates]:
self._try_fill_write()
self._connection.resume_reading()
recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
send_data = asyncio.create_task(self._do_send())
recv_data = asyncio.create_task(self._protocol.wait_packet())
conn_lost = asyncio.create_task(self._protocol.wait_closed())
done, pending = await asyncio.wait(
(recv_req, recv_data, send_data),
(
recv_req,
recv_data,
conn_lost,
),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
self._connection.pause_reading()
if pending:
for task in pending:
@ -270,24 +272,13 @@ class Sender:
self._request_event.clear()
if recv_data in done:
result = self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if conn_lost in done:
raise ConnectionResetError
if not done:
self._on_ping_timeout()
return result
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else:
# Never return
await asyncio.get_running_loop().create_future()
def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests:
if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None:
@ -298,37 +289,17 @@ class Sender:
result = self._mtp.finalize()
if result:
container_msg_id, mtp_buffer = result
self._transport.pack(mtp_buffer, self._connection.write)
for request in self._requests:
if isinstance(request.state, Serialized):
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> list[Updates]:
if not read_buffer:
raise ConnectionResetError("read 0 bytes")
self._read_buffer += read_buffer
request.state = Sent(request.state.msg_id, container_msg_id)
def _on_net_read(self, mtp_buffer: bytes) -> list[Updates]:
updates: list[Updates] = []
while self._read_buffer:
self._mtp_buffer.clear()
try:
n = self._transport.unpack(self._read_buffer, self._mtp_buffer)
except MissingBytesError:
break
else:
del self._read_buffer[:n]
self._process_mtp_buffer(updates)
self._process_mtp_buffer(mtp_buffer, updates)
return updates
def _on_net_write(self) -> None:
for req in self._requests:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_ping_timeout(self) -> None:
ping_id = generate_random_id()
self._enqueue_body(
@ -340,8 +311,8 @@ class Sender:
)
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self, updates: list[Updates]) -> None:
results = self._mtp.deserialize(self._mtp_buffer)
def _process_mtp_buffer(self, mtp_buffer: bytes, updates: list[Updates]) -> None:
results = self._mtp.deserialize(mtp_buffer)
for result in results:
if isinstance(result, Update):