diff --git a/telethon/utils.py b/telethon/utils.py index 8b12759f..0c63d143 100644 --- a/telethon/utils.py +++ b/telethon/utils.py @@ -12,6 +12,7 @@ import logging import math import mimetypes import os +import pathlib import re import struct from collections import namedtuple @@ -741,6 +742,8 @@ def _get_extension(file): """ if isinstance(file, str): return os.path.splitext(file)[-1] + elif isinstance(file, pathlib.Path): + return file.suffix elif isinstance(file, bytes): kind = imghdr.what(io.BytesIO(file)) return ('.' + kind) if kind else '' diff --git a/tests/telethon/test_utils.py b/tests/telethon/test_utils.py index d6f97e68..89391e36 100644 --- a/tests/telethon/test_utils.py +++ b/tests/telethon/test_utils.py @@ -1,3 +1,6 @@ +import io +import pathlib + from telethon import utils from telethon.tl.types import ( MessageMediaGame, Game, PhotoEmpty @@ -16,3 +19,37 @@ def test_game_input_media_memory_error(): )) input_media = utils.get_input_media(media) bytes(input_media) # <- shouldn't raise `MemoryError` + + +def test_private_get_extension(): + # Positive cases + png_header = bytes.fromhex('89 50 4e 47 0d 0a 1a 0a 00 00 00 0d 49 48 44 52') + png_buffer = io.BytesIO(png_header) + + class CustomFd: + def __init__(self, name): + self.name = name + + assert utils._get_extension('foo.bar.baz') == '.baz' + assert utils._get_extension(pathlib.Path('foo.bar.baz')) == '.baz' + assert utils._get_extension(png_header) == '.png' + assert utils._get_extension(png_buffer) == '.png' + assert utils._get_extension(png_buffer) == '.png' # make sure it did seek back + assert utils._get_extension(CustomFd('foo.bar.baz')) == '.baz' + + # Negative cases + null_header = bytes.fromhex('00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00') + null_buffer = io.BytesIO(null_header) + + empty_header = bytes() + empty_buffer = io.BytesIO(empty_header) + + assert utils._get_extension('foo') == '' + assert utils._get_extension(pathlib.Path('foo')) == '' + assert utils._get_extension(null_header) == '' + assert utils._get_extension(null_buffer) == '' + assert utils._get_extension(null_buffer) == '' # make sure it did seek back + assert utils._get_extension(empty_header) == '' + assert utils._get_extension(empty_buffer) == '' + assert utils._get_extension(empty_buffer) == '' # make sure it did seek back + assert utils._get_extension(CustomFd('foo')) == ''