mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-04-09 11:44:15 +03:00
Remove enqueuer abstraction from sender
Unnecessary complexity since Python lacks exclusive ownership.
This commit is contained in:
parent
77b49a1c88
commit
2e1321b6c9
|
@ -2,9 +2,9 @@ import asyncio
|
|||
import struct
|
||||
import time
|
||||
from abc import ABC
|
||||
from asyncio import FIRST_COMPLETED, Future, Queue, StreamReader, StreamWriter
|
||||
from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter
|
||||
from dataclasses import dataclass
|
||||
from typing import BinaryIO, Generic, List, Optional, Self, Tuple, TypeVar
|
||||
from typing import Generic, List, Optional, Self, TypeVar
|
||||
|
||||
from ..crypto.auth_key import AuthKey
|
||||
from ..mtproto import authentication
|
||||
|
@ -68,22 +68,6 @@ class Request(Generic[Return]):
|
|||
result: Future[Return]
|
||||
|
||||
|
||||
class Enqueuer:
|
||||
__slots__ = ("_queue",)
|
||||
|
||||
def __init__(self, queue: Queue[Request[object]]) -> None:
|
||||
self._queue = queue
|
||||
|
||||
def enqueue(self, request: RemoteCall[Return]) -> Future[Return]:
|
||||
body = bytes(request)
|
||||
assert len(body) >= 4
|
||||
oneshot = asyncio.get_running_loop().create_future()
|
||||
self._queue.put_nowait(
|
||||
Request(body=body, state=NotSerialized(), result=oneshot)
|
||||
)
|
||||
return oneshot
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sender:
|
||||
_reader: StreamReader
|
||||
|
@ -92,38 +76,37 @@ class Sender:
|
|||
_mtp: Mtp
|
||||
_mtp_buffer: bytearray
|
||||
_requests: List[Request[object]]
|
||||
_request_rx: Queue[Request[object]]
|
||||
_request_event: Event
|
||||
_next_ping: float
|
||||
_read_buffer: bytearray
|
||||
_write_drain_pending: bool
|
||||
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls, transport: Transport, mtp: Mtp, addr: str
|
||||
) -> Tuple[Self, Enqueuer]:
|
||||
async def connect(cls, transport: Transport, mtp: Mtp, addr: str) -> Self:
|
||||
reader, writer = await asyncio.open_connection(*addr.split(":"))
|
||||
request_queue: Queue[Request[object]] = Queue()
|
||||
|
||||
return (
|
||||
cls(
|
||||
_reader=reader,
|
||||
_writer=writer,
|
||||
_transport=transport,
|
||||
_mtp=mtp,
|
||||
_mtp_buffer=bytearray(),
|
||||
_requests=[],
|
||||
_request_rx=request_queue,
|
||||
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
|
||||
_read_buffer=bytearray(),
|
||||
_write_drain_pending=False,
|
||||
),
|
||||
Enqueuer(request_queue),
|
||||
return cls(
|
||||
_reader=reader,
|
||||
_writer=writer,
|
||||
_transport=transport,
|
||||
_mtp=mtp,
|
||||
_mtp_buffer=bytearray(),
|
||||
_requests=[],
|
||||
_request_event=Event(),
|
||||
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
|
||||
_read_buffer=bytearray(),
|
||||
_write_drain_pending=False,
|
||||
)
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
|
||||
def enqueue(self, request: RemoteCall[Return]) -> Future[bytes]:
|
||||
rx = self._enqueue_body(bytes(request))
|
||||
self._request_event.set()
|
||||
return rx
|
||||
|
||||
async def invoke(self, request: RemoteCall[Return]) -> bytes:
|
||||
rx = self._enqueue_body(bytes(request))
|
||||
return await self._step_until_receive(rx)
|
||||
|
@ -146,7 +129,7 @@ class Sender:
|
|||
async def step(self) -> List[Updates]:
|
||||
self._try_fill_write()
|
||||
|
||||
recv_req = asyncio.create_task(self._request_rx.get())
|
||||
recv_req = asyncio.create_task(self._request_event.wait())
|
||||
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
|
||||
send_data = asyncio.create_task(self._do_send())
|
||||
done, pending = await asyncio.wait(
|
||||
|
@ -161,7 +144,7 @@ class Sender:
|
|||
|
||||
result = []
|
||||
if recv_req in done:
|
||||
self._requests.append(recv_req.result())
|
||||
self._request_event.clear()
|
||||
if recv_data in done:
|
||||
result = self._on_net_read(recv_data.result())
|
||||
if send_data in done:
|
||||
|
@ -281,15 +264,12 @@ class Sender:
|
|||
return None
|
||||
|
||||
|
||||
async def connect(transport: Transport, addr: str) -> Tuple[Sender, Enqueuer]:
|
||||
sender, enqueuer = await Sender.connect(transport, Plain(), addr)
|
||||
return await generate_auth_key(sender, enqueuer)
|
||||
async def connect(transport: Transport, addr: str) -> Sender:
|
||||
sender = await Sender.connect(transport, Plain(), addr)
|
||||
return await generate_auth_key(sender)
|
||||
|
||||
|
||||
async def generate_auth_key(
|
||||
sender: Sender,
|
||||
enqueuer: Enqueuer,
|
||||
) -> Tuple[Sender, Enqueuer]:
|
||||
async def generate_auth_key(sender: Sender) -> Sender:
|
||||
request, data1 = authentication.step1()
|
||||
response = await sender.send(request)
|
||||
request, data2 = authentication.step2(data1, response)
|
||||
|
@ -301,20 +281,17 @@ async def generate_auth_key(
|
|||
time_offset = finished.time_offset
|
||||
first_salt = finished.first_salt
|
||||
|
||||
return (
|
||||
Sender(
|
||||
_reader=sender._reader,
|
||||
_writer=sender._writer,
|
||||
_transport=sender._transport,
|
||||
_mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt),
|
||||
_mtp_buffer=sender._mtp_buffer,
|
||||
_requests=sender._requests,
|
||||
_request_rx=sender._request_rx,
|
||||
_next_ping=time.time() + PING_DELAY,
|
||||
_read_buffer=sender._read_buffer,
|
||||
_write_drain_pending=sender._write_drain_pending,
|
||||
),
|
||||
enqueuer,
|
||||
return Sender(
|
||||
_reader=sender._reader,
|
||||
_writer=sender._writer,
|
||||
_transport=sender._transport,
|
||||
_mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt),
|
||||
_mtp_buffer=sender._mtp_buffer,
|
||||
_requests=sender._requests,
|
||||
_request_event=sender._request_event,
|
||||
_next_ping=time.time() + PING_DELAY,
|
||||
_read_buffer=sender._read_buffer,
|
||||
_write_drain_pending=sender._write_drain_pending,
|
||||
)
|
||||
|
||||
|
||||
|
@ -322,7 +299,7 @@ async def connect_with_auth(
|
|||
transport: Transport,
|
||||
addr: str,
|
||||
auth_key: bytes,
|
||||
) -> Tuple[Sender, Enqueuer]:
|
||||
) -> Sender:
|
||||
return await Sender.connect(
|
||||
transport, Encrypted(AuthKey.from_bytes(auth_key)), addr
|
||||
)
|
||||
|
|
|
@ -22,11 +22,11 @@ def test_invoke_encrypted_method(caplog: LogCaptureFixture) -> None:
|
|||
def timeout() -> float:
|
||||
return deadline - asyncio.get_running_loop().time()
|
||||
|
||||
sender, enqueuer = await asyncio.wait_for(
|
||||
sender = await asyncio.wait_for(
|
||||
connect(Full(), TELEGRAM_DEFAULT_TEST_DC), timeout()
|
||||
)
|
||||
|
||||
rx = enqueuer.enqueue(
|
||||
rx = sender.enqueue(
|
||||
functions.invoke_with_layer(
|
||||
layer=LAYER,
|
||||
query=functions.init_connection(
|
||||
|
|
Loading…
Reference in New Issue
Block a user