Refactor sender

This commit is contained in:
Jahongir Qurbonov 2024-10-22 15:45:44 +05:00
parent e4e7681051
commit c138e17f4a

View File

@ -3,7 +3,7 @@ import logging
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, Lock from asyncio import Event, Future, Lock
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar from typing import Generic, Optional, Protocol, Type, TypeVar
@ -166,15 +166,15 @@ class Sender:
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: AsyncReader
_writer: AsyncWriter _writer: AsyncWriter
_reading: bool
_writing: bool
_transport: Transport _transport: Transport
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray _mtp_buffer: bytearray
_updates: list[Updates] _updates: list[Updates]
_requests: list[Request[object]] _requests: list[Request[object]]
_request_event: Event _response_event: Event
_next_ping: float
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool
_step_counter: int _step_counter: int
@classmethod @classmethod
@ -196,6 +196,8 @@ class Sender:
addr=addr, addr=addr,
lock=Lock(), lock=Lock(),
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
_reading=False,
_writing=False,
_reader=reader, _reader=reader,
_writer=writer, _writer=writer,
_transport=transport, _transport=transport,
@ -203,10 +205,8 @@ class Sender:
_mtp_buffer=bytearray(), _mtp_buffer=bytearray(),
_updates=[], _updates=[],
_requests=[], _requests=[],
_request_event=Event(), _response_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(), _read_buffer=bytearray(),
_write_drain_pending=False,
_step_counter=0, _step_counter=0,
) )
@ -216,7 +216,6 @@ class Sender:
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request)) rx = self._enqueue_body(bytes(request))
self._request_event.set()
return rx return rx
async def invoke(self, request: RemoteCall[Return]) -> bytes: async def invoke(self, request: RemoteCall[Return]) -> bytes:
@ -239,58 +238,34 @@ class Sender:
return rx.result() return rx.result()
async def step(self) -> None: async def step(self) -> None:
ticket_number = self._step_counter if not self._writing:
self._writing = True
self._try_fill_write()
self._writing = False
async with self.lock: if not self._reading:
if self._step_counter == ticket_number: self._reading = True
# We're the one to drive IO. self._response_event.clear()
self._step_counter += 1 await self._try_read()
await self._step() self._reading = False
# else: different task drove IO. else:
await self._response_event.wait()
def pop_updates(self) -> list[Updates]: def pop_updates(self) -> list[Updates]:
updates = self._updates[:] updates = self._updates[:]
self._updates.clear() self._updates.clear()
return updates return updates
async def _step(self) -> None: async def _try_read(self) -> None:
self._try_fill_write() try:
async with asyncio.timeout(PING_DELAY):
recv_req = asyncio.create_task(self._request_event.wait()) recv_data = await self._reader.read(MAXIMUM_DATA)
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) except TimeoutError:
send_data = asyncio.create_task(self._do_send())
done, pending = await asyncio.wait(
(recv_req, recv_data, send_data),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
if pending:
for task in pending:
task.cancel()
await asyncio.wait(pending)
if recv_req in done:
self._request_event.clear()
if recv_data in done:
self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if not done:
self._on_ping_timeout() self._on_ping_timeout()
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else: else:
# Never return self._on_net_read(recv_data)
await asyncio.get_running_loop().create_future()
def _try_fill_write(self) -> None: def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
for request in self._requests: for request in self._requests:
if isinstance(request.state, NotSerialized): if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None: if (msg_id := self._mtp.push(request.body)) is not None:
@ -301,12 +276,11 @@ class Sender:
result = self._mtp.finalize() result = self._mtp.finalize()
if result: if result:
container_msg_id, mtp_buffer = result container_msg_id, mtp_buffer = result
for request in self._requests:
if isinstance(request.state, Serialized):
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write) self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True for request in self._requests:
if isinstance(request.state, Serialized):
request.state = Sent(request.state.msg_id, container_msg_id)
def _on_net_read(self, read_buffer: bytes) -> None: def _on_net_read(self, read_buffer: bytes) -> None:
if not read_buffer: if not read_buffer:
@ -323,11 +297,7 @@ class Sender:
else: else:
del self._read_buffer[:n] del self._read_buffer[:n]
self._process_mtp_buffer() self._process_mtp_buffer()
self._response_event.set()
def _on_net_write(self) -> None:
for req in self._requests:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_ping_timeout(self) -> None: def _on_ping_timeout(self) -> None:
ping_id = generate_random_id() ping_id = generate_random_id()
@ -338,7 +308,6 @@ class Sender:
) )
) )
) )
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self) -> None: def _process_mtp_buffer(self) -> None:
results = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(self._mtp_buffer)
@ -513,5 +482,4 @@ async def generate_auth_key(sender: Sender) -> Sender:
first_salt = finished.first_salt first_salt = finished.first_salt
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
return sender return sender