Improve sender

This commit is contained in:
Jahongir Qurbonov 2024-10-17 22:53:55 +05:00
parent f5838961dd
commit 6dde97237f
2 changed files with 42 additions and 49 deletions

View File

@ -171,6 +171,10 @@ async def connect(self: Client, reconnect: bool = False) -> None:
if self._sender and not reconnect: if self._sender and not reconnect:
return return
if reconnect:
assert self._sender
await self._sender.disconnect()
if session := await self._storage.load(): if session := await self._storage.load():
self._session = session self._session = session

View File

@ -3,7 +3,7 @@ import logging
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, Task from asyncio import FIRST_COMPLETED, Event, Future, Lock, Task
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar from typing import Generic, Optional, Protocol, Type, TypeVar
@ -172,8 +172,8 @@ class Sender:
_requests: list[Request[object]] _requests: list[Request[object]]
_request_event: Event _request_event: Event
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool _step_lock: Lock
_step_event: Event _step_counter: int
_recv_task: Optional[Task[bytes]] = None _recv_task: Optional[Task[bytes]] = None
_send_task: Optional[Task[None]] = None _send_task: Optional[Task[None]] = None
@ -191,7 +191,7 @@ class Sender:
ip, port = addr.split(":") ip, port = addr.split(":")
reader, writer = await connector(ip, int(port)) reader, writer = await connector(ip, int(port))
sender = cls( return cls(
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
@ -204,28 +204,15 @@ class Sender:
_requests=[], _requests=[],
_request_event=Event(), _request_event=Event(),
_read_buffer=bytearray(), _read_buffer=bytearray(),
_write_drain_pending=False, _step_lock=Lock(),
_step_event=Event(), _step_counter=0,
) )
sender._recv_task = asyncio.create_task(sender._do_recv())
sender._send_task = asyncio.create_task(sender._do_send())
return sender
async def disconnect(self) -> None: async def disconnect(self) -> None:
assert self._recv_task and self._send_task assert self._recv_task
assert self._send_task
( recv_task, send_task = self._recv_task, self._send_task
recv_task, self._recv_task, self._send_task = None, None
send_task,
self._recv_task,
self._send_task,
) = (
self._recv_task,
self._send_task,
None,
None,
)
recv_task.cancel() recv_task.cancel()
send_task.cancel() send_task.cancel()
@ -253,10 +240,9 @@ class Sender:
return oneshot return oneshot
async def _step_until_receive(self, rx: Future[bytes]) -> bytes: async def _step_until_receive(self, rx: Future[bytes]) -> bytes:
while True: while not rx.done():
await self.step() await self.step()
if rx.done(): return rx.result()
return rx.result()
async def step_updates(self) -> list[Updates]: async def step_updates(self) -> list[Updates]:
await self.step() await self.step()
@ -264,11 +250,22 @@ class Sender:
return updates return updates
async def step(self) -> None: async def step(self) -> None:
if self._recv_task is None or self._send_task is None: ticket_number = self._step_counter
return async with self._step_lock:
if self._step_counter == ticket_number:
# We're the one to drive IO.
await self._step()
self._step_counter += 1
self._step_event.clear() async def _step(self) -> None:
self._try_fill_write() if self._step_counter == 0:
self._try_fill_write()
self._recv_task = asyncio.create_task(self._do_recv())
self._send_task = asyncio.create_task(self._do_send())
if self._recv_task is None or self._send_task is None:
# Disconnected
return
await asyncio.wait( await asyncio.wait(
(self._recv_task, self._send_task), return_when=FIRST_COMPLETED (self._recv_task, self._send_task), return_when=FIRST_COMPLETED
@ -276,7 +273,7 @@ class Sender:
if self._recv_task.done(): if self._recv_task.done():
try: try:
buff = await self._recv_task buff = self._recv_task.result()
except TimeoutError: except TimeoutError:
self._on_ping_timeout() self._on_ping_timeout()
else: else:
@ -284,31 +281,24 @@ class Sender:
self._recv_task = asyncio.create_task(self._do_recv()) self._recv_task = asyncio.create_task(self._do_recv())
if self._send_task.done(): if self._send_task.done():
await self._send_task self._send_task.result()
self._on_net_write() self._on_net_write()
self._try_fill_write()
self._send_task = asyncio.create_task(self._do_send()) self._send_task = asyncio.create_task(self._do_send())
await self._step_event.wait()
async def _do_recv(self) -> bytes: async def _do_recv(self) -> bytes:
try: async with asyncio.timeout(PING_DELAY):
async with asyncio.timeout(PING_DELAY): return await self._reader.read(MAXIMUM_DATA)
return await self._reader.read(MAXIMUM_DATA)
finally:
self._step_event.set()
async def _do_send(self) -> None: async def _do_send(self) -> None:
try: await self._request_event.wait()
if self._write_drain_pending: await self._writer.drain()
await self._writer.drain() if not self._requests:
self._write_drain_pending = False self._request_event.clear()
else:
await self._request_event.wait()
finally:
self._step_event.set()
def _try_fill_write(self) -> None: def _try_fill_write(self) -> None:
if self._write_drain_pending: if not self._requests:
return return
for request in self._requests: for request in self._requests:
@ -326,7 +316,6 @@ class Sender:
request.state.container_msg_id = container_msg_id request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write) self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
def _on_net_read(self, read_buffer: bytes) -> None: def _on_net_read(self, read_buffer: bytes) -> None:
if not read_buffer: if not read_buffer: