Refactor sender

This commit is contained in:
Jahongir Qurbonov 2024-10-18 12:54:53 +05:00
parent 78459b50e5
commit 8d81706bf5
2 changed files with 8 additions and 7 deletions

View File

@ -5,6 +5,7 @@ import itertools
import logging import logging
import platform import platform
import re import re
from asyncio import CancelledError
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, TypeVar from typing import TYPE_CHECKING, Optional, TypeVar
@ -272,6 +273,9 @@ async def step_sender(client: Client) -> None:
else: else:
# disconnect was called, so the socket returning 0 bytes is expected # disconnect was called, so the socket returning 0 bytes is expected
return return
except CancelledError:
await disconnect(client)
return
process_socket_updates(client, updates) process_socket_updates(client, updates)

View File

@ -209,8 +209,9 @@ class Sender:
) )
async def disconnect(self) -> None: async def disconnect(self) -> None:
assert self._recv_task if not self._recv_task or not self._send_task:
assert self._send_task return
recv_task, send_task = self._recv_task, self._send_task recv_task, send_task = self._recv_task, self._send_task
async with self._lock: async with self._lock:
@ -261,11 +262,10 @@ class Sender:
async def _step(self) -> None: async def _step(self) -> None:
if self._step_counter == 0: if self._step_counter == 0:
self._try_fill_write()
self._recv_task = asyncio.create_task(self._do_recv()) self._recv_task = asyncio.create_task(self._do_recv())
self._send_task = asyncio.create_task(self._do_send()) self._send_task = asyncio.create_task(self._do_send())
if self._recv_task is None or self._send_task is None: if not self._recv_task or not self._send_task:
# Disconnected # Disconnected
return return
@ -300,9 +300,6 @@ class Sender:
self._request_event.clear() self._request_event.clear()
def _try_fill_write(self) -> None: def _try_fill_write(self) -> None:
if not self._requests:
return
for request in self._requests: for request in self._requests:
if isinstance(request.state, NotSerialized): if isinstance(request.state, NotSerialized):
if (msg_id := self._mtp.push(request.body)) is not None: if (msg_id := self._mtp.push(request.body)) is not None: