Use asyncio.BufferedProtocol in sender

This commit is contained in:
Lonami Exo 2024-10-18 19:07:45 +02:00
parent c588c74c08
commit 0e5ea59ecf
3 changed files with 88 additions and 60 deletions

View File

@ -1,5 +1,4 @@
from .sender import ( from .sender import (
MAXIMUM_DATA,
NO_PING_DISCONNECT, NO_PING_DISCONNECT,
PING_DELAY, PING_DELAY,
AsyncReader, AsyncReader,
@ -10,7 +9,6 @@ from .sender import (
) )
__all__ = [ __all__ = [
"MAXIMUM_DATA",
"NO_PING_DISCONNECT", "NO_PING_DISCONNECT",
"PING_DELAY", "PING_DELAY",
"AsyncReader", "AsyncReader",

View File

@ -0,0 +1,59 @@
import asyncio
from ..mtproto import (
MissingBytesError,
Transport,
)
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
class BufferedTransportProtocol(asyncio.BufferedProtocol):
__slots__ = (
"_transport",
"_buffer",
"_buffer_head",
"_packets",
"_output",
"_closed",
)
def __init__(self, transport: Transport):
self._transport = transport
self._buffer = bytearray(MAXIMUM_DATA)
self._buffer_head = 0
self._packets: asyncio.Queue[bytes] = asyncio.Queue()
self._output = bytearray()
self._closed = asyncio.Event()
# Method overrides
def get_buffer(self, sizehint):
return self._buffer
def buffer_updated(self, nbytes):
self._buffer_head += nbytes
while self._buffer_head:
self._output.clear()
try:
n = self._transport.unpack(
memoryview(self._buffer)[: self._buffer_head], self._output
)
except MissingBytesError as e:
print(e)
return
else:
del self._buffer[:n]
self._buffer += bytes(n)
self._buffer_head -= n
self._packets.put_nowait(bytes(self._output))
def connection_lost(self, exc):
self._closed.set()
# Custom methods
def wait_closed(self):
return self._closed.wait()
def wait_packet(self):
return self._packets.get()

View File

@ -10,11 +10,11 @@ from typing import Generic, Optional, Protocol, Type, TypeVar
from typing_extensions import Self from typing_extensions import Self
from .protocol import BufferedTransportProtocol
from ..crypto import AuthKey from ..crypto import AuthKey
from ..mtproto import ( from ..mtproto import (
BadMessageError, BadMessageError,
Encrypted, Encrypted,
MissingBytesError,
MsgId, MsgId,
Mtp, Mtp,
Plain, Plain,
@ -31,7 +31,6 @@ from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types import UpdateDeleteMessages, UpdateShort
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
PING_DELAY = 60 PING_DELAY = 60
@ -164,16 +163,13 @@ class Sender:
addr: str addr: str
lock: Lock lock: Lock
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _connection: asyncio.Transport
_writer: AsyncWriter
_transport: Transport _transport: Transport
_protocol: BufferedTransportProtocol
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray
_requests: list[Request[object]] _requests: list[Request[object]]
_request_event: Event _request_event: Event
_next_ping: float _next_ping: float
_read_buffer: bytearray
_write_drain_pending: bool
_step_counter: int _step_counter: int
@classmethod @classmethod
@ -188,29 +184,29 @@ class Sender:
base_logger: logging.Logger, base_logger: logging.Logger,
) -> Self: ) -> Self:
ip, port = addr.split(":") ip, port = addr.split(":")
reader, writer = await connector(ip, int(port)) # TODO BRING BACK SUPPORT FOR connector
connection, protocol = await asyncio.get_running_loop().create_connection(
lambda: BufferedTransportProtocol(transport), ip, int(port)
)
return cls( return cls(
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
lock=Lock(), lock=Lock(),
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
_reader=reader, _connection=connection,
_writer=writer,
_transport=transport, _transport=transport,
_protocol=protocol,
_mtp=mtp, _mtp=mtp,
_mtp_buffer=bytearray(),
_requests=[], _requests=[],
_request_event=Event(), _request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY, _next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(),
_write_drain_pending=False,
_step_counter=0, _step_counter=0,
) )
async def disconnect(self) -> None: async def disconnect(self) -> None:
self._writer.close() self._connection.close()
await self._writer.wait_closed() await self._protocol.wait_closed()
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request)) rx = self._enqueue_body(bytes(request))
@ -251,14 +247,20 @@ class Sender:
async def _step(self) -> list[Updates]: async def _step(self) -> list[Updates]:
self._try_fill_write() self._try_fill_write()
self._connection.resume_reading()
recv_req = asyncio.create_task(self._request_event.wait()) recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) recv_data = asyncio.create_task(self._protocol.wait_packet())
send_data = asyncio.create_task(self._do_send()) conn_lost = asyncio.create_task(self._protocol.wait_closed())
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
(recv_req, recv_data, send_data), (
recv_req,
recv_data,
conn_lost,
),
timeout=self._next_ping - asyncio.get_running_loop().time(), timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,
) )
self._connection.pause_reading()
if pending: if pending:
for task in pending: for task in pending:
@ -270,24 +272,13 @@ class Sender:
self._request_event.clear() self._request_event.clear()
if recv_data in done: if recv_data in done:
result = self._on_net_read(recv_data.result()) result = self._on_net_read(recv_data.result())
if send_data in done: if conn_lost in done:
self._on_net_write() raise ConnectionResetError
if not done: if not done:
self._on_ping_timeout() self._on_ping_timeout()
return result return result
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else:
# Never return
await asyncio.get_running_loop().create_future()
def _try_fill_write(self) -> None: def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests: for request in self._requests:
if isinstance(request.state, NotSerialized): if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None: if (msg_id := self._mtp.push(request.body)) is not None:
@ -298,37 +289,17 @@ class Sender:
result = self._mtp.finalize() result = self._mtp.finalize()
if result: if result:
container_msg_id, mtp_buffer = result container_msg_id, mtp_buffer = result
self._transport.pack(mtp_buffer, self._connection.write)
for request in self._requests: for request in self._requests:
if isinstance(request.state, Serialized): if isinstance(request.state, Serialized):
request.state.container_msg_id = container_msg_id request.state = Sent(request.state.msg_id, container_msg_id)
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> list[Updates]:
if not read_buffer:
raise ConnectionResetError("read 0 bytes")
self._read_buffer += read_buffer
def _on_net_read(self, mtp_buffer: bytes) -> list[Updates]:
updates: list[Updates] = [] updates: list[Updates] = []
while self._read_buffer: self._process_mtp_buffer(mtp_buffer, updates)
self._mtp_buffer.clear()
try:
n = self._transport.unpack(self._read_buffer, self._mtp_buffer)
except MissingBytesError:
break
else:
del self._read_buffer[:n]
self._process_mtp_buffer(updates)
return updates return updates
def _on_net_write(self) -> None:
for req in self._requests:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_ping_timeout(self) -> None: def _on_ping_timeout(self) -> None:
ping_id = generate_random_id() ping_id = generate_random_id()
self._enqueue_body( self._enqueue_body(
@ -340,8 +311,8 @@ class Sender:
) )
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self, updates: list[Updates]) -> None: def _process_mtp_buffer(self, mtp_buffer: bytes, updates: list[Updates]) -> None:
results = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(mtp_buffer)
for result in results: for result in results:
if isinstance(result, Update): if isinstance(result, Update):