This commit is contained in:
Pascal Jürgens 2020-01-17 12:04:05 +01:00
commit 58614e1570
3 changed files with 64 additions and 12 deletions

View File

@ -269,7 +269,7 @@ class UploadMethods:
# First check if the user passed an iterable, in which case # First check if the user passed an iterable, in which case
# we may want to send as an album if all are photo files. # we may want to send as an album if all are photo files.
if utils.is_list_like(file): if utils.is_list_like(file):
image_captions = [] media_captions = []
document_captions = [] document_captions = []
if utils.is_list_like(caption): if utils.is_list_like(caption):
captions = caption captions = caption
@ -277,28 +277,29 @@ class UploadMethods:
captions = [caption] captions = [caption]
# TODO Fix progress_callback # TODO Fix progress_callback
images = [] media = []
if force_document: if force_document:
documents = file documents = file
else: else:
documents = [] documents = []
for doc, cap in itertools.zip_longest(file, captions): for doc, cap in itertools.zip_longest(file, captions):
if utils.is_image(doc): if utils.is_image(doc) or utils.is_video(doc):
images.append(doc) media.append(doc)
image_captions.append(cap) media_captions.append(cap)
else: else:
documents.append(doc) documents.append(doc)
document_captions.append(cap) document_captions.append(cap)
result = [] result = []
while images: while media:
result += await self._send_album( result += await self._send_album(
entity, images[:10], caption=image_captions[:10], entity, media[:10], caption=media_captions[:10],
progress_callback=progress_callback, reply_to=reply_to, progress_callback=progress_callback, reply_to=reply_to,
parse_mode=parse_mode, silent=silent, schedule=schedule parse_mode=parse_mode, silent=silent, schedule=schedule,
supports_streaming=supports_streaming
) )
images = images[10:] media = media[10:]
image_captions = image_captions[10:] media_captions = media_captions[10:]
for doc, cap in zip(documents, captions): for doc, cap in zip(documents, captions):
result.append(await self.send_file( result.append(await self.send_file(
@ -349,7 +350,8 @@ class UploadMethods:
async def _send_album(self: 'TelegramClient', entity, files, caption='', async def _send_album(self: 'TelegramClient', entity, files, caption='',
progress_callback=None, reply_to=None, progress_callback=None, reply_to=None,
parse_mode=(), silent=None, schedule=None): parse_mode=(), silent=None, schedule=None,
supports_streaming=None):
"""Specialized version of .send_file for albums""" """Specialized version of .send_file for albums"""
# We don't care if the user wants to avoid cache, we will use it # We don't care if the user wants to avoid cache, we will use it
# anyway. Why? The cached version will be exactly the same thing # anyway. Why? The cached version will be exactly the same thing
@ -377,7 +379,8 @@ class UploadMethods:
# :tl:`InputMediaUploadedPhoto`. However using that will # :tl:`InputMediaUploadedPhoto`. However using that will
# make it `raise MediaInvalidError`, so we need to upload # make it `raise MediaInvalidError`, so we need to upload
# it as media and then convert that to :tl:`InputMediaPhoto`. # it as media and then convert that to :tl:`InputMediaPhoto`.
fh, fm, _ = await self._file_to_media(file) fh, fm, _ = await self._file_to_media(
file, supports_streaming=supports_streaming)
if isinstance(fm, types.InputMediaUploadedPhoto): if isinstance(fm, types.InputMediaUploadedPhoto):
r = await self(functions.messages.UploadMediaRequest( r = await self(functions.messages.UploadMediaRequest(
entity, media=fm entity, media=fm
@ -386,6 +389,15 @@ class UploadMethods:
fh.md5, fh.size, utils.get_input_photo(r.photo)) fh.md5, fh.size, utils.get_input_photo(r.photo))
fm = utils.get_input_media(r.photo) fm = utils.get_input_media(r.photo)
elif isinstance(fm, types.InputMediaUploadedDocument):
r = await self(functions.messages.UploadMediaRequest(
entity, media=fm
))
self.session.cache_file(
fh.md5, fh.size, utils.get_input_document(r.document))
fm = utils.get_input_media(
r.document, supports_streaming=supports_streaming)
if captions: if captions:
caption, msg_entities = captions.pop() caption, msg_entities = captions.pop()

View File

@ -12,6 +12,7 @@ import logging
import math import math
import mimetypes import mimetypes
import os import os
import pathlib
import re import re
import struct import struct
from collections import namedtuple from collections import namedtuple
@ -741,6 +742,8 @@ def _get_extension(file):
""" """
if isinstance(file, str): if isinstance(file, str):
return os.path.splitext(file)[-1] return os.path.splitext(file)[-1]
elif isinstance(file, pathlib.Path):
return file.suffix
elif isinstance(file, bytes): elif isinstance(file, bytes):
kind = imghdr.what(io.BytesIO(file)) kind = imghdr.what(io.BytesIO(file))
return ('.' + kind) if kind else '' return ('.' + kind) if kind else ''

View File

@ -1,3 +1,6 @@
import io
import pathlib
from telethon import utils from telethon import utils
from telethon.tl.types import ( from telethon.tl.types import (
MessageMediaGame, Game, PhotoEmpty MessageMediaGame, Game, PhotoEmpty
@ -16,3 +19,37 @@ def test_game_input_media_memory_error():
)) ))
input_media = utils.get_input_media(media) input_media = utils.get_input_media(media)
bytes(input_media) # <- shouldn't raise `MemoryError` bytes(input_media) # <- shouldn't raise `MemoryError`
def test_private_get_extension():
# Positive cases
png_header = bytes.fromhex('89 50 4e 47 0d 0a 1a 0a 00 00 00 0d 49 48 44 52')
png_buffer = io.BytesIO(png_header)
class CustomFd:
def __init__(self, name):
self.name = name
assert utils._get_extension('foo.bar.baz') == '.baz'
assert utils._get_extension(pathlib.Path('foo.bar.baz')) == '.baz'
assert utils._get_extension(png_header) == '.png'
assert utils._get_extension(png_buffer) == '.png'
assert utils._get_extension(png_buffer) == '.png' # make sure it did seek back
assert utils._get_extension(CustomFd('foo.bar.baz')) == '.baz'
# Negative cases
null_header = bytes.fromhex('00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00')
null_buffer = io.BytesIO(null_header)
empty_header = bytes()
empty_buffer = io.BytesIO(empty_header)
assert utils._get_extension('foo') == ''
assert utils._get_extension(pathlib.Path('foo')) == ''
assert utils._get_extension(null_header) == ''
assert utils._get_extension(null_buffer) == ''
assert utils._get_extension(null_buffer) == '' # make sure it did seek back
assert utils._get_extension(empty_header) == ''
assert utils._get_extension(empty_buffer) == ''
assert utils._get_extension(empty_buffer) == '' # make sure it did seek back
assert utils._get_extension(CustomFd('foo')) == ''