mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-08-06 13:10:22 +03:00
Improve step runner
This commit is contained in:
parent
88b332a498
commit
ca97078855
|
@ -61,7 +61,7 @@ async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
|
||||||
client._config, client._session.dcs, DataCenter(id=dc_id)
|
client._config, client._session.dcs, DataCenter(id=dc_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
async with client._sender._lock:
|
async with client._sender._step_lock:
|
||||||
old_sender = client._sender
|
old_sender = client._sender
|
||||||
client._sender = sender
|
client._sender = sender
|
||||||
await old_sender.disconnect()
|
await old_sender.disconnect()
|
||||||
|
|
|
@ -4,13 +4,15 @@ import struct
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from asyncio import FIRST_COMPLETED, Event, Future, Lock
|
from asyncio import FIRST_COMPLETED, Event, Future, Lock, Task
|
||||||
from collections.abc import AsyncGenerator, Iterator
|
from collections.abc import AsyncGenerator, Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Generic,
|
Generic,
|
||||||
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
|
Set,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
@ -178,7 +180,6 @@ class Sender:
|
||||||
dc_id: int
|
dc_id: int
|
||||||
addr: str
|
addr: str
|
||||||
_logger: logging.Logger
|
_logger: logging.Logger
|
||||||
_lock: Lock
|
|
||||||
_reader: AsyncReader
|
_reader: AsyncReader
|
||||||
_writer: AsyncWriter
|
_writer: AsyncWriter
|
||||||
_transport: Transport
|
_transport: Transport
|
||||||
|
@ -189,6 +190,7 @@ class Sender:
|
||||||
_request_event: Event
|
_request_event: Event
|
||||||
_next_ping: float
|
_next_ping: float
|
||||||
_read_buffer: bytearray
|
_read_buffer: bytearray
|
||||||
|
_step_lock: Lock
|
||||||
_step_generator: AsyncGenerator[None, None] | None = None
|
_step_generator: AsyncGenerator[None, None] | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -209,7 +211,6 @@ class Sender:
|
||||||
dc_id=dc_id,
|
dc_id=dc_id,
|
||||||
addr=addr,
|
addr=addr,
|
||||||
_logger=base_logger.getChild("mtsender"),
|
_logger=base_logger.getChild("mtsender"),
|
||||||
_lock=Lock(),
|
|
||||||
_reader=reader,
|
_reader=reader,
|
||||||
_writer=writer,
|
_writer=writer,
|
||||||
_transport=transport,
|
_transport=transport,
|
||||||
|
@ -220,6 +221,7 @@ class Sender:
|
||||||
_request_event=Event(),
|
_request_event=Event(),
|
||||||
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
|
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
|
||||||
_read_buffer=bytearray(),
|
_read_buffer=bytearray(),
|
||||||
|
_step_lock=Lock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
|
@ -229,7 +231,6 @@ class Sender:
|
||||||
|
|
||||||
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
|
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
|
||||||
rx = self._enqueue_body(bytes(request))
|
rx = self._enqueue_body(bytes(request))
|
||||||
self._request_event.set()
|
|
||||||
return rx
|
return rx
|
||||||
|
|
||||||
async def invoke(self, request: RemoteCall[Return]) -> bytes:
|
async def invoke(self, request: RemoteCall[Return]) -> bytes:
|
||||||
|
@ -243,11 +244,14 @@ class Sender:
|
||||||
def _enqueue_body(self, body: bytes) -> Future[bytes]:
|
def _enqueue_body(self, body: bytes) -> Future[bytes]:
|
||||||
oneshot = asyncio.get_running_loop().create_future()
|
oneshot = asyncio.get_running_loop().create_future()
|
||||||
self._requests.append(Request(body=body, state=NotSerialized(), result=oneshot))
|
self._requests.append(Request(body=body, state=NotSerialized(), result=oneshot))
|
||||||
|
self._request_event.set()
|
||||||
return oneshot
|
return oneshot
|
||||||
|
|
||||||
async def _step_until_receive(self, rx: Future[bytes]) -> bytes:
|
async def _step_until_receive(self, rx: Future[bytes]) -> bytes:
|
||||||
while True:
|
while True:
|
||||||
await self._do_step()
|
step = asyncio.create_task(self.do_step())
|
||||||
|
await asyncio.wait((step, rx), return_when=FIRST_COMPLETED)
|
||||||
|
|
||||||
if rx.done():
|
if rx.done():
|
||||||
return rx.result()
|
return rx.result()
|
||||||
|
|
||||||
|
@ -258,11 +262,8 @@ class Sender:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def do_step(self) -> None:
|
async def do_step(self) -> None:
|
||||||
async with self._lock:
|
async with self._step_lock:
|
||||||
await self._do_step()
|
await anext(self.step)
|
||||||
|
|
||||||
async def _do_step(self) -> None:
|
|
||||||
await anext(self.step)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step(self) -> AsyncGenerator[None, None]:
|
def step(self) -> AsyncGenerator[None, None]:
|
||||||
|
@ -275,42 +276,52 @@ class Sender:
|
||||||
recv_data = asyncio.create_task(self._step_recv())
|
recv_data = asyncio.create_task(self._step_recv())
|
||||||
send_data = asyncio.create_task(self._step_send())
|
send_data = asyncio.create_task(self._step_send())
|
||||||
|
|
||||||
|
pending: Set[Task[Literal[True] | None]] = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
done, pending = await asyncio.wait(
|
_, pending = await asyncio.wait(
|
||||||
(recv_req, recv_data, send_data),
|
(recv_data, send_data),
|
||||||
timeout=self._next_ping - asyncio.get_running_loop().time(),
|
timeout=self._next_ping - asyncio.get_running_loop().time(),
|
||||||
return_when=FIRST_COMPLETED,
|
return_when=FIRST_COMPLETED,
|
||||||
)
|
) # pyright: ignore [reportAssignmentType]
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
if recv_req in done:
|
if recv_req.done():
|
||||||
self._request_event.clear()
|
|
||||||
recv_req = asyncio.create_task(self._request_event.wait())
|
recv_req = asyncio.create_task(self._request_event.wait())
|
||||||
if recv_data in done:
|
if recv_data.done():
|
||||||
recv_data = asyncio.create_task(self._step_recv())
|
recv_data = asyncio.create_task(self._step_recv())
|
||||||
if send_data in done:
|
if send_data.done():
|
||||||
send_data = asyncio.create_task(self._step_send())
|
send_data = asyncio.create_task(self._step_send())
|
||||||
if not done:
|
|
||||||
self._on_ping_timeout()
|
|
||||||
finally:
|
finally:
|
||||||
if pending:
|
await self._try_cancel_tasks(pending)
|
||||||
for task in pending:
|
|
||||||
task.cancel()
|
async def _try_cancel_tasks(self, pending: set[Task]) -> None:
|
||||||
await asyncio.wait(pending)
|
if pending:
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.wait(pending)
|
||||||
|
|
||||||
async def _step_recv(self) -> None:
|
async def _step_recv(self) -> None:
|
||||||
recv_data = await self._reader.read(MAXIMUM_DATA)
|
try:
|
||||||
result = self._on_net_read(recv_data)
|
async with asyncio.timeout(PING_DELAY):
|
||||||
self._updates.extend(result)
|
recv_data = await self._reader.read(MAXIMUM_DATA)
|
||||||
|
result = self._on_net_read(recv_data)
|
||||||
|
self._updates.extend(result)
|
||||||
|
except TimeoutError:
|
||||||
|
self._on_ping_timeout()
|
||||||
|
|
||||||
async def _step_send(self) -> None:
|
async def _step_send(self) -> None:
|
||||||
self._try_fill_write()
|
await self._request_event.wait()
|
||||||
|
await self._try_fill_write()
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
self._on_net_write()
|
self._on_net_write()
|
||||||
|
|
||||||
def _try_fill_write(self) -> None:
|
if not self._requests:
|
||||||
|
self._request_event.clear()
|
||||||
|
|
||||||
|
async def _try_fill_write(self) -> None:
|
||||||
for request in self._requests:
|
for request in self._requests:
|
||||||
if isinstance(request.state, NotSerialized):
|
if isinstance(request.state, NotSerialized):
|
||||||
if (msg_id := self._mtp.push(request.body)) is not None:
|
if (msg_id := self._mtp.push(request.body)) is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user