Fix utils._get_extension not working in pathlib objects

This was found while testing #1371.
This commit is contained in:
Lonami Exo 2020-01-17 11:11:10 +01:00
parent d09f6a50b0
commit 78ee787310
2 changed files with 40 additions and 0 deletions

View File

@ -12,6 +12,7 @@ import logging
import math import math
import mimetypes import mimetypes
import os import os
import pathlib
import re import re
import struct import struct
from collections import namedtuple from collections import namedtuple
@ -741,6 +742,8 @@ def _get_extension(file):
""" """
if isinstance(file, str): if isinstance(file, str):
return os.path.splitext(file)[-1] return os.path.splitext(file)[-1]
elif isinstance(file, pathlib.Path):
return file.suffix
elif isinstance(file, bytes): elif isinstance(file, bytes):
kind = imghdr.what(io.BytesIO(file)) kind = imghdr.what(io.BytesIO(file))
return ('.' + kind) if kind else '' return ('.' + kind) if kind else ''

View File

@ -1,3 +1,6 @@
import io
import pathlib
from telethon import utils from telethon import utils
from telethon.tl.types import ( from telethon.tl.types import (
MessageMediaGame, Game, PhotoEmpty MessageMediaGame, Game, PhotoEmpty
@ -16,3 +19,37 @@ def test_game_input_media_memory_error():
)) ))
input_media = utils.get_input_media(media) input_media = utils.get_input_media(media)
bytes(input_media) # <- shouldn't raise `MemoryError` bytes(input_media) # <- shouldn't raise `MemoryError`
def test_private_get_extension():
# Positive cases
png_header = bytes.fromhex('89 50 4e 47 0d 0a 1a 0a 00 00 00 0d 49 48 44 52')
png_buffer = io.BytesIO(png_header)
class CustomFd:
def __init__(self, name):
self.name = name
assert utils._get_extension('foo.bar.baz') == '.baz'
assert utils._get_extension(pathlib.Path('foo.bar.baz')) == '.baz'
assert utils._get_extension(png_header) == '.png'
assert utils._get_extension(png_buffer) == '.png'
assert utils._get_extension(png_buffer) == '.png' # make sure it did seek back
assert utils._get_extension(CustomFd('foo.bar.baz')) == '.baz'
# Negative cases
null_header = bytes.fromhex('00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00')
null_buffer = io.BytesIO(null_header)
empty_header = bytes()
empty_buffer = io.BytesIO(empty_header)
assert utils._get_extension('foo') == ''
assert utils._get_extension(pathlib.Path('foo')) == ''
assert utils._get_extension(null_header) == ''
assert utils._get_extension(null_buffer) == ''
assert utils._get_extension(null_buffer) == '' # make sure it did seek back
assert utils._get_extension(empty_header) == ''
assert utils._get_extension(empty_buffer) == ''
assert utils._get_extension(empty_buffer) == '' # make sure it did seek back
assert utils._get_extension(CustomFd('foo')) == ''