diff --git a/client/src/telethon/_impl/mtsender/sender.py b/client/src/telethon/_impl/mtsender/sender.py index 22055479..7a327a60 100644 --- a/client/src/telethon/_impl/mtsender/sender.py +++ b/client/src/telethon/_impl/mtsender/sender.py @@ -2,9 +2,9 @@ import asyncio import struct import time from abc import ABC -from asyncio import FIRST_COMPLETED, Future, Queue, StreamReader, StreamWriter +from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter from dataclasses import dataclass -from typing import BinaryIO, Generic, List, Optional, Self, Tuple, TypeVar +from typing import Generic, List, Optional, Self, TypeVar from ..crypto.auth_key import AuthKey from ..mtproto import authentication @@ -68,22 +68,6 @@ class Request(Generic[Return]): result: Future[Return] -class Enqueuer: - __slots__ = ("_queue",) - - def __init__(self, queue: Queue[Request[object]]) -> None: - self._queue = queue - - def enqueue(self, request: RemoteCall[Return]) -> Future[Return]: - body = bytes(request) - assert len(body) >= 4 - oneshot = asyncio.get_running_loop().create_future() - self._queue.put_nowait( - Request(body=body, state=NotSerialized(), result=oneshot) - ) - return oneshot - - @dataclass class Sender: _reader: StreamReader @@ -92,38 +76,37 @@ class Sender: _mtp: Mtp _mtp_buffer: bytearray _requests: List[Request[object]] - _request_rx: Queue[Request[object]] + _request_event: Event _next_ping: float _read_buffer: bytearray _write_drain_pending: bool @classmethod - async def connect( - cls, transport: Transport, mtp: Mtp, addr: str - ) -> Tuple[Self, Enqueuer]: + async def connect(cls, transport: Transport, mtp: Mtp, addr: str) -> Self: reader, writer = await asyncio.open_connection(*addr.split(":")) - request_queue: Queue[Request[object]] = Queue() - return ( - cls( - _reader=reader, - _writer=writer, - _transport=transport, - _mtp=mtp, - _mtp_buffer=bytearray(), - _requests=[], - _request_rx=request_queue, - _next_ping=asyncio.get_running_loop().time() + PING_DELAY, - _read_buffer=bytearray(), - _write_drain_pending=False, - ), - Enqueuer(request_queue), + return cls( + _reader=reader, + _writer=writer, + _transport=transport, + _mtp=mtp, + _mtp_buffer=bytearray(), + _requests=[], + _request_event=Event(), + _next_ping=asyncio.get_running_loop().time() + PING_DELAY, + _read_buffer=bytearray(), + _write_drain_pending=False, ) - async def disconnect(self): + async def disconnect(self) -> None: self._writer.close() await self._writer.wait_closed() + def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]: + rx = self._enqueue_body(bytes(request)) + self._request_event.set() + return rx + async def invoke(self, request: RemoteCall[Return]) -> bytes: rx = self._enqueue_body(bytes(request)) return await self._step_until_receive(rx) @@ -146,7 +129,7 @@ class Sender: async def step(self) -> List[Updates]: self._try_fill_write() - recv_req = asyncio.create_task(self._request_rx.get()) + recv_req = asyncio.create_task(self._request_event.wait()) recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA)) send_data = asyncio.create_task(self._do_send()) done, pending = await asyncio.wait( @@ -161,7 +144,7 @@ class Sender: result = [] if recv_req in done: - self._requests.append(recv_req.result()) + self._request_event.clear() if recv_data in done: result = self._on_net_read(recv_data.result()) if send_data in done: @@ -281,15 +264,12 @@ class Sender: return None -async def connect(transport: Transport, addr: str) -> Tuple[Sender, Enqueuer]: - sender, enqueuer = await Sender.connect(transport, Plain(), addr) - return await generate_auth_key(sender, enqueuer) +async def connect(transport: Transport, addr: str) -> Sender: + sender = await Sender.connect(transport, Plain(), addr) + return await generate_auth_key(sender) -async def generate_auth_key( - sender: Sender, - enqueuer: Enqueuer, -) -> Tuple[Sender, Enqueuer]: +async def generate_auth_key(sender: Sender) -> Sender: request, data1 = authentication.step1() response = await sender.send(request) request, data2 = authentication.step2(data1, response) @@ -301,20 +281,17 @@ async def generate_auth_key( time_offset = finished.time_offset first_salt = finished.first_salt - return ( - Sender( - _reader=sender._reader, - _writer=sender._writer, - _transport=sender._transport, - _mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt), - _mtp_buffer=sender._mtp_buffer, - _requests=sender._requests, - _request_rx=sender._request_rx, - _next_ping=time.time() + PING_DELAY, - _read_buffer=sender._read_buffer, - _write_drain_pending=sender._write_drain_pending, - ), - enqueuer, + return Sender( + _reader=sender._reader, + _writer=sender._writer, + _transport=sender._transport, + _mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt), + _mtp_buffer=sender._mtp_buffer, + _requests=sender._requests, + _request_event=sender._request_event, + _next_ping=time.time() + PING_DELAY, + _read_buffer=sender._read_buffer, + _write_drain_pending=sender._write_drain_pending, ) @@ -322,7 +299,7 @@ async def connect_with_auth( transport: Transport, addr: str, auth_key: bytes, -) -> Tuple[Sender, Enqueuer]: +) -> Sender: return await Sender.connect( transport, Encrypted(AuthKey.from_bytes(auth_key)), addr ) diff --git a/client/tests/mtsender_test.py b/client/tests/mtsender_test.py index a23cfcb8..a666486b 100644 --- a/client/tests/mtsender_test.py +++ b/client/tests/mtsender_test.py @@ -22,11 +22,11 @@ def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None: def timeout() -> float: return deadline - asyncio.get_running_loop().time() - sender, enqueuer = await asyncio.wait_for( + sender = await asyncio.wait_for( connect(Full(), TELEGRAM_DEFAULT_TEST_DC), timeout() ) - rx = enqueuer.enqueue( + rx = sender.enqueue( functions.invoke_with_layer( layer=LAYER, query=functions.init_connection(