Initial implementation of mtsender

This commit is contained in:
Lonami Exo 2023-08-31 10:48:59 +02:00
parent 84e14c2ec4
commit 53ea35e0d5
6 changed files with 368 additions and 4 deletions

View File

@ -77,13 +77,13 @@ CONTAINER_HEADER_LEN = (8 + 4 + 4) + (4 + 4) # msg_id, seq_no, size, constructo
class Encrypted(Mtp):
def __init__(
self,
auth_key: bytes,
auth_key: AuthKey,
*,
time_offset: Optional[int] = None,
first_salt: Optional[int] = None,
compression_threshold: Optional[int] = DEFAULT_COMPRESSION_THRESHOLD,
) -> None:
self._auth_key: AuthKey = AuthKey.from_bytes(auth_key)
self._auth_key = auth_key
self._time_offset: int = time_offset or 0
self._salts: List[FutureSalt] = [
FutureSalt(valid_since=0, valid_until=0x7FFFFFFF, salt=first_salt or 0)

View File

@ -48,5 +48,5 @@ class Plain(Mtp):
)
return Deserialization(
rpc_results=[(MsgId(0), payload[20 : 20 + length])], updates=[]
rpc_results=[(MsgId(0), bytes(payload[20 : 20 + length]))], updates=[]
)

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from typing import Callable
OutFn = Callable[[bytes | bytearray | memoryview], None]

View File

@ -0,0 +1,335 @@
import asyncio
import struct
import time
from abc import ABC
from asyncio import FIRST_COMPLETED, Future, Queue, StreamReader, StreamWriter
from dataclasses import dataclass
from typing import BinaryIO, Generic, List, Optional, Self, Tuple, TypeVar
from ..crypto.auth_key import AuthKey
from ..mtproto import authentication
from ..mtproto.mtp.encrypted import Encrypted
from ..mtproto.mtp.plain import Plain
from ..mtproto.mtp.types import MsgId, Mtp
from ..mtproto.transport.abcs import MissingBytes, Transport
from ..tl.abcs import Updates
from ..tl.core.request import Request as RemoteCall
from ..tl.mtproto.functions import ping_delay_disconnect
MAXIMUM_DATA = (1024 * 1024) + (8 * 1024)
PING_DELAY = 60
NO_PING_DISCONNECT = 75
assert NO_PING_DISCONNECT > PING_DELAY
_last_id = 0
def generate_random_id() -> int:
global _last_id
if _last_id == 0:
_last_id = int(time.time() * 0x100000000)
_last_id += 1
return _last_id
class RequestState(ABC):
pass
class NotSerialized(RequestState):
pass
class Serialized(RequestState):
__slots__ = ("msg_id",)
def __init__(self, msg_id: MsgId):
self.msg_id = msg_id
class Sent(RequestState):
__slots__ = ("msg_id",)
def __init__(self, msg_id: MsgId):
self.msg_id = msg_id
Return = TypeVar("Return")
@dataclass
class Request(Generic[Return]):
body: bytes
state: RequestState
result: Future[Return]
class Enqueuer:
__slots__ = ("_queue",)
def __init__(
self,
) -> None:
# TODO use a bound
self._queue: Queue[Request[object]] = Queue()
def enqueue(self, request: RemoteCall[Return]) -> Future[Return]:
body = bytes(request)
assert len(body) >= 4
oneshot = asyncio.get_running_loop().create_future()
self._queue.put_nowait(
Request(body=body, state=NotSerialized(), result=oneshot)
)
return oneshot
@dataclass
class Sender:
_reader: StreamReader
_writer: StreamWriter
_transport: Transport
_mtp: Mtp
_mtp_buffer: bytearray
_requests: List[Request[object]]
_request_tx: Queue[Request[object]]
_request_rx: Queue[Request[object]]
_next_ping: float
_read_buffer: bytearray
_write_drain_pending: bool
@classmethod
async def connect(
cls, transport: Transport, mtp: Mtp, addr: str
) -> Tuple[Self, Enqueuer]:
reader, writer = await asyncio.open_connection(*addr.split(":"))
tx: Queue[object] = Queue()
rx = tx
return (
cls(
_reader=reader,
_writer=writer,
_transport=transport,
_mtp=mtp,
_mtp_buffer=bytearray(),
_requests=[],
_request_tx=tx,
_request_rx=rx,
_next_ping=asyncio.get_running_loop().time() + PING_DELAY,
_read_buffer=bytearray(),
_write_drain_pending=False,
),
Enqueuer(),
)
async def invoke(self, request: RemoteCall[Return]) -> bytes:
rx = self._enqueue_body(bytes(request))
return await self._step_until_receive(rx)
async def send(self, body: bytes) -> bytes:
rx = self._enqueue_body(body)
return await self._step_until_receive(rx)
def _enqueue_body(self, body: bytes) -> Future[object]:
oneshot = asyncio.get_running_loop().create_future()
self._requests.append(Request(body=body, state=NotSerialized(), result=oneshot))
return oneshot
async def _step_until_receive(self, rx: Future[object]) -> bytes:
while True:
await self.step()
if rx.done():
return rx.result()
async def step(self) -> List[Updates]:
self._try_fill_write()
recv_req = asyncio.create_task(self._request_rx.get())
recv_data = asyncio.create_task(self._reader.read(MAXIMUM_DATA))
send_data = asyncio.create_task(self._do_send())
done, pending = await asyncio.wait(
(recv_req, recv_data, send_data),
timeout=self._next_ping - asyncio.get_running_loop().time(),
return_when=FIRST_COMPLETED,
)
for task in pending:
task.cancel()
await asyncio.wait(pending)
result = []
if recv_req in done:
self._requests.append(recv_req.result())
if recv_data in done:
result = self._on_net_read(recv_data.result())
if send_data in done:
self._on_net_write()
if not done:
self._on_ping_timeout()
return result
async def _do_send(self) -> None:
if self._write_drain_pending:
await self._writer.drain()
self._write_drain_pending = False
else:
# Never return
await asyncio.get_running_loop().create_future()
def _try_fill_write(self) -> None:
if self._write_drain_pending:
return
# TODO test that the a request is only ever sent onrece
requests = [r for r in self._requests if isinstance(r.state, NotSerialized)]
if not requests:
return
msg_ids = []
for request in requests:
if (msg_id := self._mtp.push(request.body)) is not None:
msg_ids.append(msg_id)
else:
break
mtp_buffer = self._mtp.finalize()
self._transport.pack(mtp_buffer, self._writer.write)
self._write_drain_pending = True
for req, msg_id in zip(requests, msg_ids):
req.state = Serialized(msg_id)
def _on_net_read(self, read_buffer: bytes) -> List[Updates]:
if not read_buffer:
raise ConnectionResetError("read 0 bytes")
self._read_buffer += read_buffer
updates = []
while self._read_buffer:
self._mtp_buffer.clear()
try:
n = self._transport.unpack(self._read_buffer, self._mtp_buffer)
except MissingBytes:
break
else:
del self._read_buffer[:n]
self._process_mtp_buffer(updates)
return updates
def _on_net_write(self) -> None:
for req in self._requests:
if isinstance(req.state, Serialized):
req.state = Sent(req.state.msg_id)
def _on_ping_timeout(self) -> None:
ping_id = generate_random_id()
self._enqueue_body(
bytes(
ping_delay_disconnect(
ping_id=ping_id, disconnect_delay=NO_PING_DISCONNECT
)
)
)
self._next_ping = time.time() + PING_DELAY
def _process_mtp_buffer(self, updates: List[Updates]) -> None:
result = self._mtp.deserialize(self._mtp_buffer)
for update in result.updates:
try:
u = Updates.from_bytes(update)
except ValueError:
pass # TODO log
else:
updates.append(u)
for msg_id, ret in result.rpc_results:
found = False
for i in reversed(range(len(self._requests))):
req = self._requests[i]
if isinstance(req.state, Serialized) and req.state.msg_id == msg_id:
raise RuntimeError("got rpc result for unsent request")
if isinstance(req.state, Sent) and req.state.msg_id == msg_id:
found = True
if isinstance(ret, bytes):
assert len(ret) >= 4
elif isinstance(ret, Exception):
raise NotImplementedError
elif isinstance(ret, RpcError):
ret.caused_by = req.body[:4]
raise ret
elif isinstance(ret, Dropped):
raise ret
elif isinstance(ret, Deserialize):
raise ret
elif isinstance(ret, BadMessage):
# TODO test that we resend the request
req.state = NotSerialized()
break
else:
raise RuntimeError("unexpected case")
req = self._requests.pop(i)
req.result.set_result(ret)
break
if not found:
pass # TODO log
@property
def auth_key(self) -> Optional[bytes]:
if isinstance(self._mtp, Encrypted):
return self._mtp.auth_key
async def connect(transport, addr):
sender, enqueuer = await Sender.connect(transport, Plain(), addr)
return await generate_auth_key(sender, enqueuer)
async def generate_auth_key(
sender: Sender,
enqueuer: Enqueuer,
) -> Tuple[Sender, Enqueuer]:
request, data = authentication.step1()
response = await sender.send(request)
request, data = authentication.step2(data, response)
response = await sender.send(request)
request, data = authentication.step3(data, response)
response = await sender.send(request)
finished = authentication.create_key(data, response)
auth_key = finished.auth_key
time_offset = finished.time_offset
first_salt = finished.first_salt
return (
Sender(
_reader=sender._reader,
_writer=sender._writer,
_transport=sender._transport,
_mtp=Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt),
_mtp_buffer=sender._mtp_buffer,
_requests=sender._requests,
_request_tx=sender._request_tx,
_request_rx=sender._request_rx,
_next_ping=time.time() + PING_DELAY,
_read_buffer=sender._read_buffer,
_write_drain_pending=sender._write_drain_pending,
),
enqueuer,
)
async def connect_with_auth(
transport: Transport,
addr: str,
auth_key: bytes,
) -> Tuple[Sender, Enqueuer]:
return await Sender.connect(
transport, Encrypted(AuthKey.from_bytes(auth_key)), addr
)

View File

@ -0,0 +1,30 @@
import asyncio
import logging
from telethon._impl.mtproto.transport.full import Full
from telethon._impl.mtsender.sender import connect
TELEGRAM_TEST_DC_2 = "149.154.167.40:443"
TELEGRAM_DEFAULT_TEST_DC = TELEGRAM_TEST_DC_2
TEST_TIMEOUT = 10000
def test_invoke_encrypted_method(caplog) -> None:
caplog.set_level(logging.DEBUG)
async def func():
deadline = asyncio.get_running_loop().time() + TEST_TIMEOUT
def timeout():
return deadline - asyncio.get_running_loop().time()
sender, enqueuer = await asyncio.wait_for(
connect(Full(), TELEGRAM_DEFAULT_TEST_DC), timeout()
)
# TODO test enqueuer
sender, enqueuer
asyncio.run(func())