Create a method to iterate downloads

This commit is contained in:
Lonami Exo 2019-05-21 16:16:16 +02:00
parent a43830d403
commit 0b0f8f4285

View File

@ -6,6 +6,7 @@ import typing
from .users import UserMethods
from .. import utils, helpers, errors, hints
from ..requestiter import RequestIter
from ..tl import TLObject, types, functions
try:
@ -17,6 +18,140 @@ if typing.TYPE_CHECKING:
from .telegramclient import TelegramClient
# Chunk sizes for upload.getFile must be multiples of the smallest size
MIN_CHUNK_SIZE = 4096
MAX_CHUNK_SIZE = 512 * 1024
class _DirectDownloadIter(RequestIter):
async def _init(
self, file, dc_id, offset, stride, chunk_size, request_size, file_size
):
self.request = functions.upload.GetFileRequest(
file, offset=offset, limit=request_size)
self.total = file_size
self._stride = stride
self._chunk_size = chunk_size
self._last_part = None
self._exported = dc_id and self.client.session.dc_id != dc_id
if not self._exported:
# The used sender will also change if ``FileMigrateError`` occurs
self._sender = self.client._sender
else:
try:
self._sender = await self.client._borrow_exported_sender(dc_id)
except errors.DcIdInvalidError:
# Can't export a sender for the ID we are currently in
config = await self.client(functions.help.GetConfigRequest())
for option in config.dc_options:
if option.ip_address == self.client.session.server_address:
self.client.session.set_dc(
option.id, option.ip_address, option.port)
self.client.session.save()
break
# TODO Figure out why the session may have the wrong DC ID
self._sender = self.client._sender
self._exported = False
async def _load_next_chunk(self):
cur = await self._request()
self.buffer.append(cur)
if len(cur) < self.request.limit:
self.left = len(self.buffer)
await self.close()
else:
self.request.offset += self._stride
async def _request(self):
try:
result = await self._sender.send(self.request)
if isinstance(result, types.upload.FileCdnRedirect):
raise NotImplementedError # TODO Implement
else:
return result.bytes
except errors.FileMigrateError as e:
self.client._log[__name__].info('File lives in another DC')
self._sender = await self.client._borrow_exported_sender(e.new_dc)
self._exported = True
return await self._request()
async def close(self):
if not self._sender:
return
try:
if self._exported:
await self.client._return_exported_sender(self._sender)
elif self._sender != self.client._sender:
await self._sender.disconnect()
finally:
self._sender = None
async def __aenter__(self):
pass
async def __aexit__(self, *args):
await self.close()
__enter__ = helpers._sync_enter
__exit__ = helpers._sync_exit
class _GenericDownloadIter(_DirectDownloadIter):
async def _load_next_chunk(self, mask=MIN_CHUNK_SIZE - 1):
# 1. Fetch enough for one chunk
data = b''
# 1.1. ``bad`` is how much into the data we have we need to offset
bad = self.request.offset & mask
before = self.request.offset
# 1.2. We have to fetch from a valid offset, so remove that bad part
self.request.offset -= bad
done = False
while not done and len(data) - bad < self._chunk_size:
cur = await self._request()
self.request.offset += self.request.limit
data += cur
done = len(cur) < self.request.limit
# 1.3 Restore our last desired offset
self.request.offset = before
# 2. Fill the buffer with the data we have
# 2.1. Slicing ``bytes`` is expensive, yield ``memoryview`` instead
mem = memoryview(data)
# 2.2. The current chunk starts at ``bad`` offset into the data,
# and each new chunk is ``stride`` bytes apart of the other
for i in range(bad, len(data), self._stride):
self.buffer.append(mem[i:i + self._chunk_size])
# 2.3. We will yield this offset, so move to the next one
self.request.offset += self._stride
# 2.4. If we are in the last chunk, we will return the last partial data
if done:
self.left = len(self.buffer)
await self.close()
return
# 2.5. If we are not done, we can't return incomplete chunks.
if len(self.buffer[-1]) != self._chunk_size:
self._last_part = self.buffer.pop().tobytes()
# 3. Be careful with the offsets. Re-fetching a bit of data
# is fine, since it greatly simplifies things.
# TODO Try to not re-fetch data
self.request.offset -= self._stride
class DownloadMethods(UserMethods):
# region Public methods
@ -366,6 +501,150 @@ class DownloadMethods(UserMethods):
if isinstance(file, str) or in_memory:
f.close()
def iter_download(
self: 'TelegramClient',
file: 'hints.FileLike',
*,
offset: int = 0,
stride: int = None,
limit: int = None,
chunk_size: int = None,
request_size: int = MAX_CHUNK_SIZE,
file_size: int = None,
dc_id: int = None
):
"""
Iterates over a file download, yielding chunks of the file.
This method can be used to stream files in a more convenient
way, since it offers more control (pausing, resuming, etc.)
.. note::
Using a value for `offset` or `stride` which is not a multiple
of the minimum allowed `request_size`, or if `chunk_size` is
different from `request_size`, the library will need to do a
bit more work to fetch the data in the way you intend it to.
You normally shouldn't worry about this.
Arguments
file (`hints.FileLike`):
The file of which contents you want to iterate over.
offset (`int`, optional):
The offset in bytes into the file from where the
download should start. For example, if a file is
1024KB long and you just want the last 512KB, you
would use ``offset=512 * 1024``.
stride (`int`, optional):
The stride of each chunk (how much the offset should
advance between reading each chunk). This parameter
should only be used for more advanced use cases.
It must be bigger than or equal to the `chunk_size`.
limit (`int`, optional):
The limit for how many *chunks* will be yielded at most.
chunk_size (`int`, optional):
The maximum size of the chunks that will be yielded.
Note that the last chunk may be less than this value.
By default, it equals to `request_size`.
request_size (`int`, optional):
How many bytes will be requested to Telegram when more
data is required. By default, as many bytes as possible
are requested. If you would like to request data in
smaller sizes, adjust this parameter.
Note that values outside the valid range will be clamped,
and the final value will also be a multiple of the minimum
allowed size.
file_size (`int`, optional):
If the file size is known beforehand, you should set
this parameter to said value. Depending on the type of
the input file passed, this may be set automatically.
dc_id (`int`, optional):
The data center the library should connect to in order
to download the file. You shouldn't worry about this.
Yields
``bytes`` objects representing the chunks of the file if the
right conditions are met, or ``memoryview`` objects instead.
Example
.. code-block:: python
# Streaming `media` to an output file
# After the iteration ends, the sender is cleaned up
with open('photo.jpg', 'wb') as fd:
for chunk client.iter_download(media):
fd.write(chunk)
# Fetching only the header of a file (32 bytes)
# You should manually close the iterator in this case.
stream = client.iter_download(media, request_size=32)
header = next(stream)
stream.close()
assert len(header) == 32
# Fetching only the header, inside of an ``async def``
async def main():
stream = client.iter_download(media, request_size=32)
header = await stream.__anext__()
await stream.close()
assert len(header) == 32
"""
if chunk_size is None:
chunk_size = request_size
if limit is None and file_size is not None:
limit = (file_size + chunk_size - 1) // chunk_size
if stride is None:
stride = chunk_size
elif stride < chunk_size:
raise ValueError('stride must be >= chunk_size')
request_size -= request_size % MIN_CHUNK_SIZE
if request_size < MIN_CHUNK_SIZE:
request_size = MIN_CHUNK_SIZE
elif request_size > MAX_CHUNK_SIZE:
request_size = MAX_CHUNK_SIZE
old_dc = dc_id
dc_id, file = utils.get_input_location(file)
if dc_id is None:
dc_id = old_dc
if chunk_size == request_size \
and offset % MIN_CHUNK_SIZE == 0 \
and stride % MIN_CHUNK_SIZE == 0:
cls = _DirectDownloadIter
self._log[__name__].info('Starting direct file download in chunks of '
'%d at %d, stride %d', request_size, offset, stride)
else:
cls = _GenericDownloadIter
self._log[__name__].info('Starting indirect file download in chunks of '
'%d at %d, stride %d', request_size, offset, stride)
return cls(
self,
limit,
file=file,
dc_id=dc_id,
offset=offset,
stride=stride,
chunk_size=chunk_size,
request_size=request_size,
file_size=file_size
)
# endregion
# region Private methods