Use asyncio.BufferedProtocol in sender

This commit is contained in:
Lonami Exo 2024-10-18 19:07:45 +02:00 committed by Jahongir Qurbonov
parent e4e7681051
commit b0a06a97ae
3 changed files with 94 additions and 114 deletions

View File

@ -1,5 +1,4 @@
from .sender import ( from .sender import (
MAXIMUM_DATA,
NO_PING_DISCONNECT, NO_PING_DISCONNECT,
PING_DELAY, PING_DELAY,
AsyncReader, AsyncReader,
@ -10,7 +9,6 @@ from .sender import (
) )
__all__ = [ __all__ = [
"MAXIMUM_DATA",
"NO_PING_DISCONNECT", "NO_PING_DISCONNECT",
"PING_DELAY", "PING_DELAY",
"AsyncReader", "AsyncReader",

View File

@ -0,0 +1,43 @@
from __future__ import annotations
import weakref
from asyncio import BufferedProtocol, Event
from typing import TYPE_CHECKING, Literal
from typing_extensions import Buffer
if TYPE_CHECKING:
from .sender import Sender
class BufferedStreamingProtocol(BufferedProtocol):
__slots__ = ("_sender", "_closed")
def __init__(self, sender: Sender) -> None:
self._sender = weakref.ref(sender)
self._closed = Event()
@property
def sender(self) -> Sender:
if (sender := self._sender()) is None:
raise ValueError("Sender has been garbage-collected")
return sender
# Method overrides
def get_buffer(self, sizehint: int) -> Buffer:
return self.sender._read_buffer
def buffer_updated(self, nbytes: int) -> None:
self.sender._on_buffer_updated(nbytes)
def connection_lost(self, exc: Exception | None) -> None:
self._closed.set()
# Custom methods
def is_closed(self) -> bool:
return self._closed.is_set()
async def wait_closed(self) -> Literal[True]:
return await self._closed.wait()

View File

@ -3,7 +3,7 @@ import logging
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, Lock from asyncio import Future, Lock, StreamReader, StreamWriter
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar from typing import Generic, Optional, Protocol, Type, TypeVar
@ -30,6 +30,7 @@ from ..tl.core import Serializable
from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.mtproto.functions import ping_delay_disconnect
from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types import UpdateDeleteMessages, UpdateShort
from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages
from .protocol import BufferedStreamingProtocol
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
@ -51,50 +52,6 @@ def generate_random_id() -> int:
return _last_id return _last_id
class AsyncReader(Protocol):
"""
A :class:`asyncio.StreamReader`-like class.
"""
async def read(self, n: int) -> bytes:
"""
Must behave like :meth:`asyncio.StreamReader.read`.
:param n:
Amount of bytes to read at most.
"""
raise NotImplementedError
class AsyncWriter(Protocol):
"""
A :class:`asyncio.StreamWriter`-like class.
"""
def write(self, data: bytes | bytearray | memoryview) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.write`.
:param data:
Data that must be entirely written or buffered until :meth:`drain` is called.
"""
async def drain(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.drain`.
"""
def close(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.close`.
"""
async def wait_closed(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.wait_closed`.
"""
class Connector(Protocol): class Connector(Protocol):
""" """
A *Connector* is any function that takes in the following two positional parameters as input: A *Connector* is any function that takes in the following two positional parameters as input:
@ -102,7 +59,7 @@ class Connector(Protocol):
* The ``ip`` address as a :class:`str`. This might be either a IPv4 or IPv6. * The ``ip`` address as a :class:`str`. This might be either a IPv4 or IPv6.
* The ``port`` as a :class:`int`. This will be a number below 2¹, often 443. * The ``port`` as a :class:`int`. This will be a number below 2¹, often 443.
and returns a :class:`tuple`\\ [:class:`AsyncReader`, :class:`AsyncWriter`]. and returns a :class:`tuple`\\ [:class:`StreamReader`, :class:`StreamWriter`].
You can use a custom connector to connect to Telegram through proxies. You can use a custom connector to connect to Telegram through proxies.
The library will only ever open remote connections through this function. The library will only ever open remote connections through this function.
@ -120,7 +77,7 @@ class Connector(Protocol):
The :doc:`/concepts/datacenters` concept has examples on how to combine proxy libraries with Telethon. The :doc:`/concepts/datacenters` concept has examples on how to combine proxy libraries with Telethon.
""" """
async def __call__(self, ip: str, port: int) -> tuple[AsyncReader, AsyncWriter]: async def __call__(self, ip: str, port: int) -> tuple[StreamReader, StreamWriter]:
raise NotImplementedError raise NotImplementedError
@ -164,18 +121,18 @@ class Sender:
addr: str addr: str
lock: Lock lock: Lock
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: StreamReader
_writer: AsyncWriter _writer: StreamWriter
_transport: Transport _transport: Transport
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray _mtp_buffer: bytearray
_updates: list[Updates] _updates: list[Updates]
_requests: list[Request[object]] _requests: list[Request[object]]
_request_event: Event _read_buffer_head: int
_next_ping: float
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool _response_state: asyncio.Event
_step_counter: int _step_counter: int
_protocol: BufferedStreamingProtocol | None = None
@classmethod @classmethod
async def connect( async def connect(
@ -191,7 +148,7 @@ class Sender:
ip, port = addr.split(":") ip, port = addr.split(":")
reader, writer = await connector(ip, int(port)) reader, writer = await connector(ip, int(port))
return cls( sender = cls(
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
lock=Lock(), lock=Lock(),
@ -203,20 +160,25 @@ class Sender:
_mtp_buffer=bytearray(), _mtp_buffer=bytearray(),
_updates=[], _updates=[],
_requests=[], _requests=[],
_request_event=Event(), _read_buffer_head=0,
_next_ping=asyncio.get_running_loop().time() + PING_DELAY, _read_buffer=bytearray(MAXIMUM_DATA),
_read_buffer=bytearray(), _response_state=asyncio.Event(),
_write_drain_pending=False,
_step_counter=0, _step_counter=0,
) )
protocol = BufferedStreamingProtocol(sender)
sender._writer.transport.set_protocol(protocol)
sender._protocol = protocol
return sender
async def disconnect(self) -> None: async def disconnect(self) -> None:
assert self._protocol
self._writer.close() self._writer.close()
await self._writer.wait_closed() await self._protocol.wait_closed()
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request)) rx = self._enqueue_body(bytes(request))
self._request_event.set()
return rx return rx
async def invoke(self, request: RemoteCall[Return]) -> bytes: async def invoke(self, request: RemoteCall[Return]) -> bytes:
@ -254,43 +216,15 @@ class Sender:
return updates return updates
async def _step(self) -> None: async def _step(self) -> None:
assert self._protocol
if self._protocol.is_closed():
raise ConnectionResetError
self._response_state.clear()
self._try_fill_write() self._try_fill_write()
await self._wait_response()
recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
send_data = asyncio.create_task(self._do_send())
done, pending = await asyncio.wait(
(recv_req, recv_data, send_data),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
if pending:
for task in pending:
task.cancel()
await asyncio.wait(pending)
if recv_req in done:
self._request_event.clear()
if recv_data in done:
self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if not done:
self._on_ping_timeout()
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else:
# Never return
await asyncio.get_running_loop().create_future()
def _try_fill_write(self) -> None: def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests: for request in self._requests:
if isinstance(request.state, NotSerialized): if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None: if (msg_id := self._mtp.push(request.body)) is not None:
@ -301,33 +235,40 @@ class Sender:
result = self._mtp.finalize() result = self._mtp.finalize()
if result: if result:
container_msg_id, mtp_buffer = result container_msg_id, mtp_buffer = result
for request in self._requests:
if isinstance(request.state, Serialized):
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write) self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True for request in self._requests:
if isinstance(request.state, Serialized):
request.state = Sent(request.state.msg_id, container_msg_id)
def _on_net_read(self, read_buffer: bytes) -> None: async def _wait_response(self) -> None:
if not read_buffer: try:
raise ConnectionResetError("read 0 bytes") async with asyncio.timeout(PING_DELAY):
await self._response_state.wait()
except TimeoutError:
self._on_ping_timeout()
self._read_buffer += read_buffer def _on_buffer_updated(self, nbytes: int) -> None:
self._read_buffer_head += nbytes
while self._read_buffer: while self._read_buffer_head:
self._mtp_buffer.clear() self._mtp_buffer.clear()
try: try:
n = self._transport.unpack(self._read_buffer, self._mtp_buffer) n = self._transport.unpack(
memoryview(self._read_buffer)[: self._read_buffer_head],
self._mtp_buffer,
)
except MissingBytesError: except MissingBytesError:
break return
else: else:
del self._read_buffer[:n] del self._read_buffer[:n]
self._read_buffer += bytes(n)
self._read_buffer_head -= n
self._process_mtp_buffer() self._process_mtp_buffer()
def _on_net_write(self) -> None: self._response_state.set()
for req in self._requests:
if isinstance(req.state, Serialized): def _on_conn_closed(self) -> None:
req.state = Sent(req.state.msg_id, req.state.container_msg_id) self._response_state.set()
def _on_ping_timeout(self) -> None: def _on_ping_timeout(self) -> None:
ping_id = generate_random_id() ping_id = generate_random_id()
@ -338,10 +279,9 @@ class Sender:
) )
) )
) )
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self) -> None: def _process_mtp_buffer(self) -> None:
results = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(bytes(self._mtp_buffer))
for result in results: for result in results:
if isinstance(result, Update): if isinstance(result, Update):
@ -513,5 +453,4 @@ async def generate_auth_key(sender: Sender) -> Sender:
first_salt = finished.first_salt first_salt = finished.first_salt
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
return sender return sender