Support downloading web documents

This commit is contained in:
Lonami Exo 2018-08-01 00:37:25 +02:00
parent 6d1bc227aa
commit 76c7217000
2 changed files with 64 additions and 13 deletions

View File

@ -8,6 +8,12 @@ from .users import UserMethods
from .. import utils, helpers, errors from .. import utils, helpers, errors
from ..tl import TLObject, types, functions from ..tl import TLObject, types, functions
try:
import aiohttp
except ImportError:
aiohttp = None
__log__ = logging.getLogger(__name__) __log__ = logging.getLogger(__name__)
@ -140,6 +146,10 @@ class DownloadMethods(UserMethods):
return self._download_contact( return self._download_contact(
media, file media, file
) )
elif isinstance(media, (types.WebDocument, types.WebDocumentNoProxy)):
return await self._download_web_document(
media, file, progress_callback
)
async def download_file( async def download_file(
self, input_location, file=None, *, part_size_kb=None, self, input_location, file=None, *, part_size_kb=None,
@ -298,19 +308,12 @@ class DownloadMethods(UserMethods):
progress_callback=progress_callback) progress_callback=progress_callback)
return file return file
async def _download_document( @staticmethod
self, document, file, date, progress_callback): def _get_kind_and_names(attributes):
"""Specialized version of .download_media() for documents.""" """Gets kind and possible names for :tl:`DocumentAttribute`."""
if isinstance(document, types.MessageMediaDocument):
document = document.document
if not isinstance(document, types.Document):
return
file_size = document.size
kind = 'document' kind = 'document'
possible_names = [] possible_names = []
for attr in document.attributes: for attr in attributes:
if isinstance(attr, types.DocumentAttributeFilename): if isinstance(attr, types.DocumentAttributeFilename):
possible_names.insert(0, attr.file_name) possible_names.insert(0, attr.file_name)
@ -327,13 +330,24 @@ class DownloadMethods(UserMethods):
elif attr.voice: elif attr.voice:
kind = 'voice' kind = 'voice'
return kind, possible_names
async def _download_document(
        self, document, file, date, progress_callback):
    """
    Specialized version of .download_media() for documents.

    A :tl:`MessageMediaDocument` wrapper is unwrapped to its inner
    document first. Anything that is not a :tl:`Document` after that is
    ignored and ``None`` is returned; otherwise the resolved file name
    that was handed to ``download_file`` is returned.
    """
    # Unwrap the media container so only the bare document is handled below
    doc = (
        document.document
        if isinstance(document, types.MessageMediaDocument)
        else document
    )
    if not isinstance(doc, types.Document):
        return None

    # The attributes decide both the kind of media and the candidate names
    doc_kind, name_candidates = self._get_kind_and_names(doc.attributes)
    target = self._get_proper_filename(
        file,
        doc_kind,
        utils.get_extension(doc),
        date=date,
        possible_names=name_candidates
    )

    await self.download_file(
        doc,
        target,
        file_size=doc.size,
        progress_callback=progress_callback
    )
    return target
@ -373,6 +387,42 @@ class DownloadMethods(UserMethods):
return file return file
@classmethod
async def _download_web_document(cls, web, file, progress_callback):
    """
    Specialized version of .download_media() for web documents.

    Fetches ``web.url`` over HTTP through aiohttp and writes the response
    body to ``file`` in chunks of 128 KiB.

    Args:
        web: the web document to download. Its ``url`` is always read;
            its ``attributes`` are read only when ``file`` is a ``str``.
        file: either a ``str``, which is turned into the final file name
            by ``_get_proper_filename`` and opened for binary writing, or
            any other object, which is written to through ``.write()``
            and left open for the caller.
        progress_callback: currently unused (see the TODO below).

    Returns:
        The resolved file name when ``file`` was a ``str``; otherwise the
        object that was passed in as ``file``.

    Raises:
        ValueError: if the optional ``aiohttp`` dependency is missing.
    """
    if not aiohttp:
        raise ValueError(
            'Cannot download web documents without the aiohttp '
            'dependency install it (pip install aiohttp)'
        )

    # TODO Better way to get opened handles of files and auto-close
    # Decide once whether this method owns the handle; only handles that
    # were opened here get closed in the ``finally`` below.
    opened_here = isinstance(file, str)
    if opened_here:
        kind, possible_names = cls._get_kind_and_names(web.attributes)
        file = cls._get_proper_filename(
            file, kind, utils.get_extension(web),
            possible_names=possible_names
        )
        f = open(file, 'wb')
    else:
        f = file

    try:
        # ClientSession has to be entered with ``async with``; the plain
        # ``with`` form is rejected by aiohttp 3.x with a TypeError.
        async with aiohttp.ClientSession() as session:
            # TODO Use progress_callback; get content length from response
            # https://github.com/telegramdesktop/tdesktop/blob/c7e773dd9aeba94e2be48c032edc9a78bb50234e/Telegram/SourceFiles/ui/images.cpp#L1318-L1319
            async with session.get(web.url) as response:
                # NOTE(review): the HTTP status is not checked, so an
                # error page would be saved as if it were the document.
                while True:
                    chunk = await response.content.read(128 * 1024)
                    if not chunk:
                        break
                    f.write(chunk)
    finally:
        if opened_here:
            f.close()

    # download_media() returns this value, so hand back the destination
    # just like the other specialized download helpers do.
    return file
@staticmethod @staticmethod
def _get_proper_filename(file, kind, extension, def _get_proper_filename(file, kind, extension,
date=None, possible_names=None): date=None, possible_names=None):

View File

@ -92,7 +92,8 @@ def get_extension(media):
# Documents will come with a mime type # Documents will come with a mime type
if isinstance(media, types.MessageMediaDocument): if isinstance(media, types.MessageMediaDocument):
media = media.document media = media.document
if isinstance(media, types.Document): if isinstance(media, (
types.Document, types.WebDocument, types.WebDocumentNoProxy)):
if media.mime_type == 'application/octet-stream': if media.mime_type == 'application/octet-stream':
# Octet stream are just bytes, which have no default extension # Octet stream are just bytes, which have no default extension
return '' return ''