diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 46c89d9c..3c50ab8a 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -1,13 +1,19 @@ import asyncio import logging import struct +import sys import time from abc import ABC from asyncio import FIRST_COMPLETED, Event, Future, Lock -from collections import deque -from collections.abc import Iterator +from collections.abc import AsyncGenerator, Iterator from dataclasses import dataclass -from typing import AsyncIterator, Deque, Generic, Optional, Protocol, Type, TypeVar +from typing import ( + Generic, + Optional, + Protocol, + Type, + TypeVar, +) from typing_extensions import Self @@ -32,6 +38,14 @@ from ..tl.mtproto.functions import ping_delay_disconnect from ..tl.types import UpdateDeleteMessages, UpdateShort from ..tl.types.messages import AffectedFoundMessages, AffectedHistory, AffectedMessages +if sys.version_info < (3, 10): + Y = TypeVar("Y") + S = TypeVar("S") + + async def anext(it: AsyncGenerator[Y, S]) -> Y: + return await it.__anext__() + + MAXIMUM_DATA = (1024 * 1024) + (8 * 1024) PING_DELAY = 60 @@ -170,12 +184,12 @@ class Sender: _transport: Transport _mtp: Mtp _mtp_buffer: bytearray - _updates: Deque[Updates] + _updates: list[Updates] _requests: list[Request[object]] _request_event: Event _next_ping: float _read_buffer: bytearray - _step_iterator: AsyncIterator[None] | None = None + _step_generator: AsyncGenerator[None, None] | None = None @classmethod async def connect( @@ -201,7 +215,7 @@ class Sender: _transport=transport, _mtp=mtp, _mtp_buffer=bytearray(), - _updates=deque(), + _updates=[], _requests=[], _request_event=Event(), _next_ping=asyncio.get_running_loop().time() + PING_DELAY, @@ -209,6 +223,7 @@ class Sender: ) async def disconnect(self) -> None: + await self.step.aclose() self._writer.close() await self._writer.wait_closed() @@ -238,55 +253,59 @@ class Sender: async def get_updates(self) -> list[Updates]: await self.do_step() - return ( - [self._updates.popleft() for _ in range(ulen)] - if (ulen := len(self._updates)) - else [] - ) + result = self._updates.copy() + self._updates.clear() + return result async def do_step(self) -> None: async with self._lock: await self._do_step() async def _do_step(self) -> None: - await self.step.__anext__() + await anext(self.step) @property - def step(self) -> AsyncIterator[None]: - if self._step_iterator is None: - self._step_iterator = self._step() - return self._step_iterator + def step(self) -> AsyncGenerator[None, None]: + if self._step_generator is None: + self._step_generator = self._step() + return self._step_generator - async def _step(self) -> AsyncIterator[None]: + async def _step(self) -> AsyncGenerator[None, None]: self._try_fill_write() recv_req = asyncio.create_task(self._request_event.wait()) recv_data = asyncio.create_task(self._step_recv()) send_data = asyncio.create_task(self._step_send()) - while True: - done, _ = await asyncio.wait( - (recv_req, recv_data, send_data), - timeout=self._next_ping - asyncio.get_running_loop().time(), - return_when=FIRST_COMPLETED, - ) + try: + while True: + done, pending = await asyncio.wait( + (recv_req, recv_data, send_data), + timeout=self._next_ping - asyncio.get_running_loop().time(), + return_when=FIRST_COMPLETED, + ) - yield + yield - if recv_req in done: - self._request_event.clear() - recv_req = asyncio.create_task(self._request_event.wait()) - if recv_data in done: - recv_data = asyncio.create_task(self._step_recv()) - if send_data in done: - send_data = asyncio.create_task(self._step_send()) - if not done: - self._on_ping_timeout() + if recv_req in done: + self._request_event.clear() + recv_req = asyncio.create_task(self._request_event.wait()) + if recv_data in done: + recv_data = asyncio.create_task(self._step_recv()) + if send_data in done: + send_data = asyncio.create_task(self._step_send()) + if not done: + self._on_ping_timeout() + finally: + if pending: + for task in pending: + task.cancel() + await asyncio.wait(pending) async def _step_recv(self) -> None: recv_data = await self._reader.read(MAXIMUM_DATA) - updates = self._on_net_read(recv_data) - self._updates.extend(updates) + result = self._on_net_read(recv_data) + self._updates.extend(result) async def _step_send(self) -> None: self._try_fill_write()