Remove enqueuer abstraction from sender

Unnecessary complexity since Python lacks exclusive ownership.
This commit is contained in:
Lonami Exo 2023-09-01 11:57:41 +02:00
parent 77b49a1c88
commit 2e1321b6c9
2 changed files with 41 additions and 64 deletions

View File

@ -2,9 +2,9 @@ import asyncio
import struct
import time
from abc import ABC
from asyncio import FIRST_COMPLETED, Future, Queue, StreamReader, StreamWriter
from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter
from dataclasses import dataclass
from typing import BinaryIO, Generic, List, Optional, Self, Tuple, TypeVar
from typing import Generic, List, Optional, Self, TypeVar
from ..crypto.auth_key import AuthKey
from ..mtproto import authentication
@ -68,22 +68,6 @@ class Request(Generic[Return]):
result: Future[Return]
class Enqueuer:
__slots__ = ("_queue",)
def __init__(self, queue: Queue[Request[object]]) -> None:
self._queue = queue
def enqueue(self, request: RemoteCall[Return]) -> Future[Return]:
body = bytes(request)
assert len(body) >= 4
oneshot = asyncio.get_running_loop().create_future()
self._queue.put_nowait(
Request(body=body, state=NotSerialized(), result=oneshot)
)
return oneshot
@dataclass
class Sender:
_reader: StreamReader
@ -92,38 +76,37 @@ class Sender:
_mtp: Mtp
_mtp_buffer: bytearray
_requests: List[Request[object]]
_request_rx: Queue[Request[object]]
_request_event: Event
_next_ping: float
_read_buffer: bytearray
_write_drain_pending: bool
@classmethod
async def connect(
cls, transport: Transport, mtp: Mtp, addr: str
) -> Tuple[Self, Enqueuer]:
async def connect(cls, transport: Transport, mtp: Mtp, addr: str) -> Self:
reader, writer = await asyncio.open_connection(*addr.split(":"))
request_queue: Queue[Request[object]] = Queue()
return (
cls(
_reader=reader,
_writer=writer,
_transport=transport,
_mtp=mtp,
_mtp_buffer=bytearray(),
_requests=[],
_request_rx=request_queue,
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(),
_write_drain_pending=False,
),
Enqueuer(request_queue),
return cls(
_reader=reader,
_writer=writer,
_transport=transport,
_mtp=mtp,
_mtp_buffer=bytearray(),
_requests=[],
_request_event=Event(),
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(),
_write_drain_pending=False,
)
async def disconnect(self):
async def disconnect(self) -> None:
self._writer.close()
await self._writer.wait_closed()
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
rx = self._enqueue_body(bytes(request))
self._request_event.set()
return rx
async def invoke(self, request: RemoteCall[Return]) -> bytes:
rx = self._enqueue_body(bytes(request))
return await self._step_until_receive(rx)
@ -146,7 +129,7 @@ class Sender:
async def step(self) -> List[Updates]:
self._try_fill_write()
recv_req = asyncio.create_task(self._request_rx.get())
recv_req = asyncio.create_task(self._request_event.wait())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
send_data = asyncio.create_task(self._do_send())
done, pending = await asyncio.wait(
@ -161,7 +144,7 @@ class Sender:
result = []
if recv_req in done:
self._requests.append(recv_req.result())
self._request_event.clear()
if recv_data in done:
result = self._on_net_read(recv_data.result())
if send_data in done:
@ -281,15 +264,12 @@ class Sender:
return None
async def connect(transport: Transport, addr: str) -> Tuple[Sender, Enqueuer]:
sender, enqueuer = await Sender.connect(transport, Plain(), addr)
return await generate_auth_key(sender, enqueuer)
async def connect(transport: Transport, addr: str) -> Sender:
sender = await Sender.connect(transport, Plain(), addr)
return await generate_auth_key(sender)
async def generate_auth_key(
sender: Sender,
enqueuer: Enqueuer,
) -> Tuple[Sender, Enqueuer]:
async def generate_auth_key(sender: Sender) -> Sender:
request, data1 = authentication.step1()
response = await sender.send(request)
request, data2 = authentication.step2(data1, response)
@ -301,20 +281,17 @@ async def generate_auth_key(
time_offset = finished.time_offset
first_salt = finished.first_salt
return (
Sender(
_reader=sender._reader,
_writer=sender._writer,
_transport=sender._transport,
_mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt),
_mtp_buffer=sender._mtp_buffer,
_requests=sender._requests,
_request_rx=sender._request_rx,
_next_ping=time.time() + PING_DELAY,
_read_buffer=sender._read_buffer,
_write_drain_pending=sender._write_drain_pending,
),
enqueuer,
return Sender(
_reader=sender._reader,
_writer=sender._writer,
_transport=sender._transport,
_mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt),
_mtp_buffer=sender._mtp_buffer,
_requests=sender._requests,
_request_event=sender._request_event,
_next_ping=time.time() + PING_DELAY,
_read_buffer=sender._read_buffer,
_write_drain_pending=sender._write_drain_pending,
)
@ -322,7 +299,7 @@ async def connect_with_auth(
transport: Transport,
addr: str,
auth_key: bytes,
) -> Tuple[Sender, Enqueuer]:
) -> Sender:
return await Sender.connect(
transport, Encrypted(AuthKey.from_bytes(auth_key)), addr
)

View File

@ -22,11 +22,11 @@ def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None:
def timeout() -> float:
return deadline - asyncio.get_running_loop().time()
sender, enqueuer = await asyncio.wait_for(
sender = await asyncio.wait_for(
connect(Full(), TELEGRAM_DEFAULT_TEST_DC), timeout()
)
rx = enqueuer.enqueue(
rx = sender.enqueue(
functions.invoke_with_layer(
layer=LAYER,
query=functions.init_connection(