Rewrite sender

This commit is contained in:
Jahongir Qurbonov 2024-10-15 17:36:16 +05:00
parent 2ad5e9658d
commit d99e776e2b
2 changed files with 56 additions and 41 deletions

View File

@ -265,7 +265,7 @@ async def invoke_request(
async def step_sender(client: Client) -> None: async def step_sender(client: Client) -> None:
try: try:
assert client._sender assert client._sender
updates = await client._sender.step() updates = await client._sender.get_updates()
except ConnectionError: except ConnectionError:
if client.connected: if client.connected:
raise raise

View File

@ -4,9 +4,10 @@ import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, Lock from asyncio import FIRST_COMPLETED, Event, Future, Lock
from collections import deque
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar from typing import AsyncIterator, Deque, Generic, Optional, Protocol, Type, TypeVar
from typing_extensions import Self from typing_extensions import Self
@ -169,11 +170,12 @@ class Sender:
_transport: Transport _transport: Transport
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray _mtp_buffer: bytearray
_updates: Deque[Updates]
_requests: list[Request[object]] _requests: list[Request[object]]
_request_event: Event _request_event: Event
_next_ping: float _next_ping: float
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool _step_iterator: AsyncIterator[None] | None = None
@classmethod @classmethod
async def connect( async def connect(
@ -199,11 +201,11 @@ class Sender:
_transport=transport, _transport=transport,
_mtp=mtp, _mtp=mtp,
_mtp_buffer=bytearray(), _mtp_buffer=bytearray(),
_updates=deque(),
_requests=[], _requests=[],
_request_event=Event(), _request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY, _next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(), _read_buffer=bytearray(),
_write_drain_pending=False,
) )
async def disconnect(self) -> None: async def disconnect(self) -> None:
@ -230,54 +232,68 @@ class Sender:
async def _step_until_receive(self, rx: Future[bytes]) -> bytes: async def _step_until_receive(self, rx: Future[bytes]) -> bytes:
while True: while True:
await self.step() await self._do_step()
if rx.done(): if rx.done():
return rx.result() return rx.result()
async def step(self) -> list[Updates]: async def get_updates(self) -> list[Updates]:
async with self._lock: await self.do_step()
return await self._step() return (
[self._updates.popleft() for _ in range(ulen)]
if (ulen := len(self._updates))
else []
)
async def _step(self) -> list[Updates]: async def do_step(self) -> None:
async with self._lock:
await self._do_step()
async def _do_step(self) -> None:
await self.step.__anext__()
@property
def step(self) -> AsyncIterator[None]:
if self._step_iterator is None:
self._step_iterator = self._step()
return self._step_iterator
async def _step(self) -> AsyncIterator[None]:
self._try_fill_write() self._try_fill_write()
recv_req = asyncio.create_task(self._request_event.wait()) recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) recv_data = asyncio.create_task(self._step_recv())
send_data = asyncio.create_task(self._do_send()) send_data = asyncio.create_task(self._step_send())
done, pending = await asyncio.wait(
(recv_req, recv_data, send_data),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
if pending: while True:
for task in pending: done, _ = await asyncio.wait(
task.cancel() (recv_req, recv_data, send_data),
await asyncio.wait(pending) timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
result = [] yield
if recv_req in done:
self._request_event.clear()
if recv_data in done:
result = self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if not done:
self._on_ping_timeout()
return result
async def _do_send(self) -> None: if recv_req in done:
if self._write_drain_pending: self._request_event.clear()
await self._writer.drain() recv_req = asyncio.create_task(self._request_event.wait())
self._write_drain_pending = False if recv_data in done:
else: recv_data = asyncio.create_task(self._step_recv())
# Never return if send_data in done:
await asyncio.get_running_loop().create_future() send_data = asyncio.create_task(self._step_send())
if not done:
self._on_ping_timeout()
async def _step_recv(self) -> None:
recv_data = await self._reader.read(MAXIMUM_DATA)
updates = self._on_net_read(recv_data)
self._updates.extend(updates)
async def _step_send(self) -> None:
self._try_fill_write()
await self._writer.drain()
self._on_net_write()
def _try_fill_write(self) -> None: def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests: for request in self._requests:
if isinstance(request.state, NotSerialized): if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None: if (msg_id := self._mtp.push(request.body)) is not None:
@ -293,7 +309,6 @@ class Sender:
request.state.container_msg_id = container_msg_id request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write) self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> list[Updates]: def _on_net_read(self, read_buffer: bytes) -> list[Updates]:
if not read_buffer: if not read_buffer: