Rewrite sender

This commit is contained in:
Jahongir Qurbonov 2024-10-15 17:36:16 +05:00
parent 2ad5e9658d
commit d99e776e2b
2 changed files with 56 additions and 41 deletions

View File

@ -265,7 +265,7 @@ async def invoke_request(
async def step_sender(client: Client) -> None:
try:
assert client._sender
updates = await client._sender.step()
updates = await client._sender.get_updates()
except ConnectionError:
if client.connected:
raise

View File

@ -4,9 +4,10 @@ import struct
import time
from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, Lock
from collections import deque
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar
from typing import AsyncIterator, Deque, Generic, Optional, Protocol, Type, TypeVar
from typing_extensions import Self
@ -169,11 +170,12 @@ class Sender:
_transport: Transport
_mtp: Mtp
_mtp_buffer: bytearray
_updates: Deque[Updates]
_requests: list[Request[object]]
_request_event: Event
_next_ping: float
_read_buffer: bytearray
_write_drain_pending: bool
_step_iterator: AsyncIterator[None] | None = None
@classmethod
async def connect(
@ -199,11 +201,11 @@ class Sender:
_transport=transport,
_mtp=mtp,
_mtp_buffer=bytearray(),
_updates=deque(),
_requests=[],
_request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(),
_write_drain_pending=False,
)
async def disconnect(self) -> None:
@ -230,54 +232,68 @@ class Sender:
async def _step_until_receive(self, rx: Future[bytes]) -> bytes:
while True:
await self.step()
await self._do_step()
if rx.done():
return rx.result()
async def step(self) -> list[Updates]:
async with self._lock:
return await self._step()
async def get_updates(self) -> list[Updates]:
await self.do_step()
return (
[self._updates.popleft() for _ in range(ulen)]
if (ulen := len(self._updates))
else []
)
async def _step(self) -> list[Updates]:
async def do_step(self) -> None:
async with self._lock:
await self._do_step()
async def _do_step(self) -> None:
await self.step.__anext__()
@property
def step(self) -> AsyncIterator[None]:
if self._step_iterator is None:
self._step_iterator = self._step()
return self._step_iterator
async def _step(self) -> AsyncIterator[None]:
self._try_fill_write()
recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
send_data = asyncio.create_task(self._do_send())
done, pending = await asyncio.wait(
(recv_req, recv_data, send_data),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
recv_data = asyncio.create_task(self._step_recv())
send_data = asyncio.create_task(self._step_send())
if pending:
for task in pending:
task.cancel()
await asyncio.wait(pending)
while True:
done, _ = await asyncio.wait(
(recv_req, recv_data, send_data),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
result = []
if recv_req in done:
self._request_event.clear()
if recv_data in done:
result = self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if not done:
self._on_ping_timeout()
return result
yield
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else:
# Never return
await asyncio.get_running_loop().create_future()
if recv_req in done:
self._request_event.clear()
recv_req = asyncio.create_task(self._request_event.wait())
if recv_data in done:
recv_data = asyncio.create_task(self._step_recv())
if send_data in done:
send_data = asyncio.create_task(self._step_send())
if not done:
self._on_ping_timeout()
async def _step_recv(self) -> None:
recv_data = await self._reader.read(MAXIMUM_DATA)
updates = self._on_net_read(recv_data)
self._updates.extend(updates)
async def _step_send(self) -> None:
self._try_fill_write()
await self._writer.drain()
self._on_net_write()
def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests:
if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None:
@ -293,7 +309,6 @@ class Sender:
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> list[Updates]:
if not read_buffer: