Refactor sender to avoid cancelling tasks (#4497)

This commit is contained in:
Jahongir Qurbonov 2024-10-24 22:00:00 +05:00 committed by GitHub
parent e4e7681051
commit f15af530fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 75 deletions

View File

@ -61,7 +61,6 @@ async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
client._config, client._session.dcs, DataCenter(id=dc_id) client._config, client._session.dcs, DataCenter(id=dc_id)
) )
async with client._sender.lock:
old_sender = client._sender old_sender = client._sender
client._sender = sender client._sender = sender
await old_sender.disconnect() await old_sender.disconnect()

View File

@ -3,7 +3,7 @@ import logging
import struct import struct
import time import time
from abc import ABC from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, Lock from asyncio import Event, Future
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, Optional, Protocol, Type, TypeVar from typing import Generic, Optional, Protocol, Type, TypeVar
@ -162,20 +162,19 @@ class Request(Generic[Return]):
class Sender: class Sender:
dc_id: int dc_id: int
addr: str addr: str
lock: Lock
_logger: logging.Logger _logger: logging.Logger
_reader: AsyncReader _reader: AsyncReader
_writer: AsyncWriter _writer: AsyncWriter
_reading: bool
_writing: bool
_step_done: Event
_transport: Transport _transport: Transport
_mtp: Mtp _mtp: Mtp
_mtp_buffer: bytearray _mtp_buffer: bytearray
_updates: list[Updates] _updates: list[Updates]
_requests: list[Request[object]] _requests: list[Request[object]]
_request_event: Event
_next_ping: float _next_ping: float
_read_buffer: bytearray _read_buffer: bytearray
_write_drain_pending: bool
_step_counter: int
@classmethod @classmethod
async def connect( async def connect(
@ -194,29 +193,28 @@ class Sender:
return cls( return cls(
dc_id=dc_id, dc_id=dc_id,
addr=addr, addr=addr,
lock=Lock(),
_logger=base_logger.getChild("mtsender"), _logger=base_logger.getChild("mtsender"),
_reader=reader, _reader=reader,
_writer=writer, _writer=writer,
_reading=False,
_writing=False,
_step_done=Event(),
_transport=transport, _transport=transport,
_mtp=mtp, _mtp=mtp,
_mtp_buffer=bytearray(), _mtp_buffer=bytearray(),
_updates=[], _updates=[],
_requests=[], _requests=[],
_request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY, _next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(), _read_buffer=bytearray(),
_write_drain_pending=False,
_step_counter=0,
) )
async def disconnect(self) -> None: async def disconnect(self) -> None:
self._writer.close() self._writer.close()
await self._writer.wait_closed() await self._writer.wait_closed()
self._step_done.set()
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request)) rx = self._enqueue_body(bytes(request))
self._request_event.set()
return rx return rx
async def invoke(self, request: RemoteCall[Return]) -> bytes: async def invoke(self, request: RemoteCall[Return]) -> bytes:
@ -239,56 +237,47 @@ class Sender:
return rx.result() return rx.result()
async def step(self) -> None: async def step(self) -> None:
ticket_number = self._step_counter if not self._writing:
self._writing = True
await self._do_write()
self._writing = False
async with self.lock: if not self._reading:
if self._step_counter == ticket_number: self._reading = True
# We're the one to drive IO. await self._do_read()
self._step_counter += 1 self._reading = False
await self._step()
# else: different task drove IO. if not self._step_done.is_set():
await self._step_done.wait()
def pop_updates(self) -> list[Updates]: def pop_updates(self) -> list[Updates]:
updates = self._updates[:] updates = self._updates[:]
self._updates.clear() self._updates.clear()
return updates return updates
async def _step(self) -> None: async def _do_read(self) -> None:
self._try_fill_write() self._step_done.clear()
recv_req = asyncio.create_task(self._request_event.wait()) timeout = self._next_ping - asyncio.get_running_loop().time()
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) try:
send_data = asyncio.create_task(self._do_send()) async with asyncio.timeout(timeout):
done, pending = await asyncio.wait( recv_data = await self._reader.read(MAXIMUM_DATA)
(recv_req, recv_data, send_data), except TimeoutError:
timeout=self._next_ping - asyncio.get_running_loop().time(), pass
return_when=FIRST_COMPLETED,
)
if pending:
for task in pending:
task.cancel()
await asyncio.wait(pending)
if recv_req in done:
self._request_event.clear()
if recv_data in done:
self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if not done:
self._on_ping_timeout()
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else: else:
# Never return self._on_net_read(recv_data)
await asyncio.get_running_loop().create_future() finally:
self._try_timeout_ping()
self._step_done.set()
def _try_fill_write(self) -> None: async def _do_write(self) -> None:
if self._write_drain_pending: self._step_done.clear()
await self._try_fill_write()
self._try_timeout_ping()
self._step_done.set()
async def _try_fill_write(self) -> None:
if not self._requests:
return return
for request in self._requests: for request in self._requests:
@ -301,12 +290,27 @@ class Sender:
result = self._mtp.finalize() result = self._mtp.finalize()
if result: if result:
container_msg_id, mtp_buffer = result container_msg_id, mtp_buffer = result
for request in self._requests:
if isinstance(request.state, Serialized):
request.state.container_msg_id = container_msg_id
self._transport.pack(mtp_buffer, self._writer.write) self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True await self._writer.drain()
for request in self._requests:
if isinstance(request.state, Serialized):
request.state = Sent(request.state.msg_id, container_msg_id)
def _try_timeout_ping(self) -> None:
current_time = asyncio.get_running_loop().time()
if current_time >= self._next_ping:
ping_id = generate_random_id()
self._enqueue_body(
bytes(
ping_delay_disconnect(
ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT
)
)
)
self._next_ping = current_time + PING_DELAY
def _on_net_read(self, read_buffer: bytes) -> None: def _on_net_read(self, read_buffer: bytes) -> None:
if not read_buffer: if not read_buffer:
@ -324,22 +328,6 @@ class Sender:
del self._read_buffer[:n] del self._read_buffer[:n]
self._process_mtp_buffer() self._process_mtp_buffer()
def _on_net_write(self) -> None:
for req in self._requests:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id, req.state.container_msg_id)
def _on_ping_timeout(self) -> None:
ping_id = generate_random_id()
self._enqueue_body(
bytes(
ping_delay_disconnect(
ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT
)
)
)
self._next_ping = asyncio.get_running_loop().time() + PING_DELAY
def _process_mtp_buffer(self) -> None: def _process_mtp_buffer(self) -> None:
results = self._mtp.deserialize(self._mtp_buffer) results = self._mtp.deserialize(self._mtp_buffer)
@ -513,5 +501,4 @@ async def generate_auth_key(sender: Sender) -> Sender:
first_salt = finished.first_salt first_salt = finished.first_salt
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt) sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
return sender return sender