Make get_input_* methods slightly smarter

This commit is contained in:
Lonami Exo 2017-07-10 16:09:20 +02:00
parent bdee94eaf3
commit 88c4cdfb52

View File

@ -6,9 +6,9 @@ from mimetypes import add_type, guess_extension
from .tl.types import (
Channel, ChannelForbidden, Chat, ChatEmpty, ChatForbidden, ChatFull,
ChatPhoto, InputPeerChannel, InputPeerChat, InputPeerUser,
ChatPhoto, InputPeerChannel, InputPeerChat, InputPeerUser, InputPeerEmpty,
MessageMediaDocument, MessageMediaPhoto, PeerChannel, InputChannel,
UserEmpty, InputUser, InputUserEmpty, InputUserSelf,
UserEmpty, InputUser, InputUserEmpty, InputUserSelf, InputPeerSelf,
PeerChat, PeerUser, User, UserFull, UserProfilePhoto)
@ -57,6 +57,9 @@ def get_input_peer(entity):
return entity
if isinstance(entity, User):
if entity.is_self:
return InputPeerSelf()
else:
return InputPeerUser(entity.id, entity.access_hash)
if any(isinstance(entity, c) for c in (
@ -68,8 +71,14 @@ def get_input_peer(entity):
return InputPeerChannel(entity.id, entity.access_hash)
# Less common cases
if isinstance(entity, UserEmpty):
return InputPeerEmpty()
if isinstance(entity, InputUser):
return InputPeerUser(entity.user_id, entity.access_hash)
if isinstance(entity, UserFull):
return InputPeerUser(entity.user.id, entity.user.access_hash)
return get_input_peer(entity.user)
if isinstance(entity, ChatFull):
return InputPeerChat(entity.id)
@ -110,6 +119,9 @@ def get_input_user(entity):
if isinstance(entity, UserFull):
return get_input_user(entity.user)
if isinstance(entity, InputPeerUser):
return InputUser(entity.user_id, entity.access_hash)
raise ValueError('Cannot cast {} to any kind of InputUser.'
.format(type(entity).__name__))