mirror of
				https://github.com/LonamiWebs/Telethon.git
				synced 2025-11-04 01:47:27 +03:00 
			
		
		
		
	Merge pull request #296 from LonamiWebs/entity-db
Use a custom database for entities
This commit is contained in:
		
						commit
						427a6aabaa
					
				| 
						 | 
				
			
			@ -1,7 +1,5 @@
 | 
			
		|||
import os
 | 
			
		||||
import re
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from functools import lru_cache
 | 
			
		||||
from mimetypes import guess_type
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -17,6 +15,7 @@ from .errors import (
 | 
			
		|||
)
 | 
			
		||||
from .network import ConnectionMode
 | 
			
		||||
from .tl import TLObject
 | 
			
		||||
from .tl.entity_database import EntityDatabase
 | 
			
		||||
from .tl.functions.account import (
 | 
			
		||||
    GetPasswordRequest
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -127,7 +126,7 @@ class TelegramClient(TelegramBareClient):
 | 
			
		|||
 | 
			
		||||
    def send_code_request(self, phone):
 | 
			
		||||
        """Sends a code request to the specified phone number"""
 | 
			
		||||
        phone = self._parse_phone(phone)
 | 
			
		||||
        phone = EntityDatabase.parse_phone(phone) or self._phone
 | 
			
		||||
        result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
 | 
			
		||||
        self._phone = phone
 | 
			
		||||
        self._phone_code_hash = result.phone_code_hash
 | 
			
		||||
| 
						 | 
				
			
			@ -160,7 +159,7 @@ class TelegramClient(TelegramBareClient):
 | 
			
		|||
        if phone and not code:
 | 
			
		||||
            return self.send_code_request(phone)
 | 
			
		||||
        elif code:
 | 
			
		||||
            phone = self._parse_phone(phone)
 | 
			
		||||
            phone = EntityDatabase.parse_phone(phone) or self._phone
 | 
			
		||||
            phone_code_hash = phone_code_hash or self._phone_code_hash
 | 
			
		||||
            if not phone:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
| 
						 | 
				
			
			@ -849,7 +848,6 @@ class TelegramClient(TelegramBareClient):
 | 
			
		|||
 | 
			
		||||
    # region Small utilities to make users' life easier
 | 
			
		||||
 | 
			
		||||
    @lru_cache()
 | 
			
		||||
    def get_entity(self, entity):
 | 
			
		||||
        """Turns an entity into a valid Telegram user or chat.
 | 
			
		||||
           If "entity" is a string which can be converted to an integer,
 | 
			
		||||
| 
						 | 
				
			
			@ -866,71 +864,52 @@ class TelegramClient(TelegramBareClient):
 | 
			
		|||
           If the entity is neither, and it's not a TLObject, an
 | 
			
		||||
           error will be raised.
 | 
			
		||||
        """
 | 
			
		||||
        # TODO Maybe cache both the contacts and the entities.
 | 
			
		||||
        # If an user cannot be found, force a cache update through
 | 
			
		||||
        # a public method (since users may change their username)
 | 
			
		||||
        input_entity = None
 | 
			
		||||
        if isinstance(entity, TLObject):
 | 
			
		||||
        try:
 | 
			
		||||
            return self.session.entities[entity]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        if isinstance(entity, int) or (
 | 
			
		||||
                isinstance(entity, TLObject) and
 | 
			
		||||
                # crc32(b'InputPeer') and crc32(b'Peer')
 | 
			
		||||
            if type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687):
 | 
			
		||||
                input_entity = self.get_input_entity(entity)
 | 
			
		||||
            else:
 | 
			
		||||
                # TODO Don't assume it's a valid entity
 | 
			
		||||
                return entity
 | 
			
		||||
                type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
 | 
			
		||||
            ie = self.get_input_entity(entity)
 | 
			
		||||
            if isinstance(ie, InputPeerUser):
 | 
			
		||||
                self.session.process_entities(GetUsersRequest([ie]))
 | 
			
		||||
            elif isinstance(ie, InputPeerChat):
 | 
			
		||||
                self.session.process_entities(GetChatsRequest([ie.chat_id]))
 | 
			
		||||
            elif isinstance(ie, InputPeerChannel):
 | 
			
		||||
                self.session.process_entities(GetChannelsRequest([ie]))
 | 
			
		||||
 | 
			
		||||
        elif isinstance(entity, int):
 | 
			
		||||
            input_entity = self.get_input_entity(entity)
 | 
			
		||||
 | 
			
		||||
        if input_entity:
 | 
			
		||||
            if isinstance(input_entity, InputPeerUser):
 | 
			
		||||
                return self(GetUsersRequest([input_entity]))[0]
 | 
			
		||||
            elif isinstance(input_entity, InputPeerChat):
 | 
			
		||||
                return self(GetChatsRequest([input_entity.chat_id])).chats[0]
 | 
			
		||||
            elif isinstance(input_entity, InputPeerChannel):
 | 
			
		||||
                return self(GetChannelsRequest([input_entity])).chats[0]
 | 
			
		||||
            # The result of Get*Request has been processed and the entity
 | 
			
		||||
            # cached if it was found.
 | 
			
		||||
            return self.session.entities[ie]
 | 
			
		||||
 | 
			
		||||
        if isinstance(entity, str):
 | 
			
		||||
            stripped_phone = self._parse_phone(entity, ignore_saved=True)
 | 
			
		||||
            if stripped_phone.isdigit():
 | 
			
		||||
                contacts = self(GetContactsRequest(0))
 | 
			
		||||
                try:
 | 
			
		||||
                    return next(
 | 
			
		||||
                        u for u in contacts.users
 | 
			
		||||
                        if u.phone and u.phone.endswith(stripped_phone)
 | 
			
		||||
                    )
 | 
			
		||||
                except StopIteration:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        'Could not find user with phone {}, '
 | 
			
		||||
                        'add them to your contacts first'.format(entity)
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                username = entity.strip('@').lower()
 | 
			
		||||
                resolved = self(ResolveUsernameRequest(username))
 | 
			
		||||
                for c in resolved.chats:
 | 
			
		||||
                    if getattr(c, 'username', '').lower() == username:
 | 
			
		||||
                        return c
 | 
			
		||||
                for u in resolved.users:
 | 
			
		||||
                    if getattr(u, 'username', '').lower() == username:
 | 
			
		||||
                        return u
 | 
			
		||||
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    'Could not find user with username {}'.format(entity)
 | 
			
		||||
                )
 | 
			
		||||
            return self._get_entity_from_string(entity)
 | 
			
		||||
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            'Cannot turn "{}" into any entity (user or chat)'.format(entity)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _parse_phone(self, phone, ignore_saved=False):
 | 
			
		||||
        if isinstance(phone, int):
 | 
			
		||||
            phone = str(phone)
 | 
			
		||||
        elif phone:
 | 
			
		||||
            phone = re.sub(r'[+()\s-]', '', phone)
 | 
			
		||||
 | 
			
		||||
        if ignore_saved:
 | 
			
		||||
            return phone
 | 
			
		||||
    def _get_entity_from_string(self, string):
 | 
			
		||||
        """Gets an entity from the given string, which may be a phone or
 | 
			
		||||
           an username, and processes all the found entities on the session.
 | 
			
		||||
        """
 | 
			
		||||
        phone = EntityDatabase.parse_phone(string)
 | 
			
		||||
        if phone:
 | 
			
		||||
            entity = phone
 | 
			
		||||
            self.session.process_entities(self(GetContactsRequest(0)))
 | 
			
		||||
        else:
 | 
			
		||||
            return phone or self._phone
 | 
			
		||||
            entity = string.strip('@').lower()
 | 
			
		||||
            self.session.process_entities(self(ResolveUsernameRequest(entity)))
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            return self.session.entities[entity]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                'Could not find user with username {}'.format(entity)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def get_input_entity(self, peer):
 | 
			
		||||
        """Gets the input entity given its PeerUser, PeerChat, PeerChannel.
 | 
			
		||||
| 
						 | 
				
			
			@ -940,11 +919,16 @@ class TelegramClient(TelegramBareClient):
 | 
			
		|||
           If this Peer hasn't been seen before by the library, all dialogs
 | 
			
		||||
           will loaded, and their entities saved to the session file.
 | 
			
		||||
 | 
			
		||||
           If even after
 | 
			
		||||
           If even after it's not found, a ValueError is raised.
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            # First try to get the entity from cache, otherwise figure it out
 | 
			
		||||
            self.session.entities.get_input_entity(peer)
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        if isinstance(peer, str):
 | 
			
		||||
            # Let .get_entity resolve the username or phone (full entity)
 | 
			
		||||
            peer = self.get_entity(peer)
 | 
			
		||||
            return utils.get_input_peer(self._get_entity_from_string(peer))
 | 
			
		||||
 | 
			
		||||
        is_peer = False
 | 
			
		||||
        if isinstance(peer, int):
 | 
			
		||||
| 
						 | 
				
			
			@ -964,16 +948,12 @@ class TelegramClient(TelegramBareClient):
 | 
			
		|||
                'Cannot turn "{}" into an input entity.'.format(peer)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            return self.session.get_input_entity(peer)
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        if self.session.save_entities:
 | 
			
		||||
            # Not found, look in the dialogs (this will save the users)
 | 
			
		||||
            self.get_dialogs(limit=None)
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                return self.session.get_input_entity(peer)
 | 
			
		||||
                self.session.entities.get_input_entity(peer)
 | 
			
		||||
            except KeyError:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										188
									
								
								telethon/tl/entity_database.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										188
									
								
								telethon/tl/entity_database.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,188 @@
 | 
			
		|||
from threading import Lock
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
from .. import utils
 | 
			
		||||
from ..tl import TLObject
 | 
			
		||||
from ..tl.types import User, Chat, Channel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EntityDatabase:
 | 
			
		||||
    def __init__(self, input_list=None, enabled=True, enabled_full=True):
 | 
			
		||||
        """Creates a new entity database with an initial load of "Input"
 | 
			
		||||
           entities, if any.
 | 
			
		||||
 | 
			
		||||
           If 'enabled', input entities will be saved. The whole entity
 | 
			
		||||
           will be saved if both 'enabled' and 'enabled_full' are True.
 | 
			
		||||
        """
 | 
			
		||||
        self.enabled = enabled
 | 
			
		||||
        self.enabled_full = enabled_full
 | 
			
		||||
 | 
			
		||||
        self._lock = Lock()
 | 
			
		||||
        self._entities = {}  # marked_id: user|chat|channel
 | 
			
		||||
 | 
			
		||||
        if input_list:
 | 
			
		||||
            self._input_entities = {k: v for k, v in input_list}
 | 
			
		||||
        else:
 | 
			
		||||
            self._input_entities = {}  # marked_id: hash
 | 
			
		||||
 | 
			
		||||
        # TODO Allow disabling some extra mappings
 | 
			
		||||
        self._username_id = {}  # username: marked_id
 | 
			
		||||
        self._phone_id = {}  # phone: marked_id
 | 
			
		||||
 | 
			
		||||
    def process(self, tlobject):
 | 
			
		||||
        """Processes all the found entities on the given TLObject,
 | 
			
		||||
           unless .enabled is False.
 | 
			
		||||
 | 
			
		||||
           Returns True if new input entities were added.
 | 
			
		||||
        """
 | 
			
		||||
        if not self.enabled:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # Save all input entities we know of
 | 
			
		||||
        if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'):
 | 
			
		||||
            # This may be a list of users already for instance
 | 
			
		||||
            return self.expand(tlobject)
 | 
			
		||||
 | 
			
		||||
        entities = []
 | 
			
		||||
        if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
 | 
			
		||||
            entities.extend(tlobject.chats)
 | 
			
		||||
        if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
 | 
			
		||||
            entities.extend(tlobject.users)
 | 
			
		||||
 | 
			
		||||
        return self.expand(entities)
 | 
			
		||||
 | 
			
		||||
    def expand(self, entities):
 | 
			
		||||
        """Adds new input entities to the local database unconditionally.
 | 
			
		||||
           Unknown types will be ignored.
 | 
			
		||||
        """
 | 
			
		||||
        if not entities or not self.enabled:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        new = []  # Array of entities (User, Chat, or Channel)
 | 
			
		||||
        new_input = {}  # Dictionary of {entity_marked_id: access_hash}
 | 
			
		||||
        for e in entities:
 | 
			
		||||
            if not isinstance(e, TLObject):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                p = utils.get_input_peer(e, allow_self=False)
 | 
			
		||||
                new_input[utils.get_peer_id(p, add_mark=True)] = \
 | 
			
		||||
                    getattr(p, 'access_hash', 0)  # chats won't have hash
 | 
			
		||||
 | 
			
		||||
                if self.enabled_full:
 | 
			
		||||
                    if isinstance(e, User) \
 | 
			
		||||
                            or isinstance(e, Chat) \
 | 
			
		||||
                            or isinstance(e, Channel):
 | 
			
		||||
                        new.append(e)
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            before = len(self._input_entities)
 | 
			
		||||
            self._input_entities.update(new_input)
 | 
			
		||||
            for e in new:
 | 
			
		||||
                self._add_full_entity(e)
 | 
			
		||||
            return len(self._input_entities) != before
 | 
			
		||||
 | 
			
		||||
    def _add_full_entity(self, entity):
 | 
			
		||||
        """Adds a "full" entity (User, Chat or Channel, not "Input*"),
 | 
			
		||||
           despite the value of self.enabled and self.enabled_full.
 | 
			
		||||
 | 
			
		||||
           Not to be confused with UserFull, ChatFull, or ChannelFull,
 | 
			
		||||
           "full" means simply not "Input*".
 | 
			
		||||
        """
 | 
			
		||||
        marked_id = utils.get_peer_id(
 | 
			
		||||
            utils.get_input_peer(entity, allow_self=False), add_mark=True
 | 
			
		||||
        )
 | 
			
		||||
        try:
 | 
			
		||||
            old_entity = self._entities[marked_id]
 | 
			
		||||
            old_entity.__dict__.update(entity.__dict__)  # Keep old references
 | 
			
		||||
 | 
			
		||||
            # Update must delete old username and phone
 | 
			
		||||
            username = getattr(old_entity, 'username', None)
 | 
			
		||||
            if username:
 | 
			
		||||
                del self._username_id[username.lower()]
 | 
			
		||||
 | 
			
		||||
            phone = getattr(old_entity, 'phone', None)
 | 
			
		||||
            if phone:
 | 
			
		||||
                del self._phone_id[phone]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            # Add new entity
 | 
			
		||||
            self._entities[marked_id] = entity
 | 
			
		||||
 | 
			
		||||
        # Always update username or phone if any
 | 
			
		||||
        username = getattr(entity, 'username', None)
 | 
			
		||||
        if username:
 | 
			
		||||
            self._username_id[username.lower()] = marked_id
 | 
			
		||||
 | 
			
		||||
        phone = getattr(entity, 'phone', None)
 | 
			
		||||
        if phone:
 | 
			
		||||
            self._username_id[phone] = marked_id
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, key):
 | 
			
		||||
        """Accepts a digit only string as phone number,
 | 
			
		||||
           otherwise it's treated as an username.
 | 
			
		||||
 | 
			
		||||
           If an integer is given, it's treated as the ID of the desired User.
 | 
			
		||||
           The ID given won't try to be guessed as the ID of a chat or channel,
 | 
			
		||||
           as there may be an user with that ID, and it would be unreliable.
 | 
			
		||||
 | 
			
		||||
           If a Peer is given (PeerUser, PeerChat, PeerChannel),
 | 
			
		||||
           its specific entity is retrieved as User, Chat or Channel.
 | 
			
		||||
           Note that megagroups are channels with .megagroup = True.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(key, str):
 | 
			
		||||
            phone = EntityDatabase.parse_phone(key)
 | 
			
		||||
            if phone:
 | 
			
		||||
                return self._phone_id[phone]
 | 
			
		||||
            else:
 | 
			
		||||
                key = key.lstrip('@').lower()
 | 
			
		||||
                return self._entities[self._username_id[key]]
 | 
			
		||||
 | 
			
		||||
        if isinstance(key, int):
 | 
			
		||||
            return self._entities[key]  # normal IDs are assumed users
 | 
			
		||||
 | 
			
		||||
        if isinstance(key, TLObject):
 | 
			
		||||
            sc = type(key).SUBCLASS_OF_ID
 | 
			
		||||
            if sc == 0x2d45687:
 | 
			
		||||
                # Subclass of "Peer"
 | 
			
		||||
                return self._entities[utils.get_peer_id(key, add_mark=True)]
 | 
			
		||||
            elif sc in {0x2da17977, 0xc5af5d94, 0x6d44b7db}:
 | 
			
		||||
                # Subclass of "User", "Chat" or "Channel"
 | 
			
		||||
                return key
 | 
			
		||||
 | 
			
		||||
        raise KeyError(key)
 | 
			
		||||
 | 
			
		||||
    def __delitem__(self, key):
 | 
			
		||||
        target = self[key]
 | 
			
		||||
        del self._entities[key]
 | 
			
		||||
        if getattr(target, 'username'):
 | 
			
		||||
            del self._username_id[target.username]
 | 
			
		||||
 | 
			
		||||
    # TODO Allow search by name by tokenizing the input and return a list
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def parse_phone(phone):
 | 
			
		||||
        """Parses the given phone, or returns None if it's invalid"""
 | 
			
		||||
        if isinstance(phone, int):
 | 
			
		||||
            return str(phone)
 | 
			
		||||
        else:
 | 
			
		||||
            phone = re.sub(r'[+()\s-]', '', str(phone))
 | 
			
		||||
            if phone.isdigit():
 | 
			
		||||
                return phone
 | 
			
		||||
 | 
			
		||||
    def get_input_entity(self, peer):
 | 
			
		||||
        try:
 | 
			
		||||
            return self._input_entities[utils.get_peer_id(peer, add_mark=True)]
 | 
			
		||||
        except ValueError as e:
 | 
			
		||||
            raise KeyError(peer) from e
 | 
			
		||||
 | 
			
		||||
    def get_input_list(self):
 | 
			
		||||
        return list(self._input_entities.items())
 | 
			
		||||
 | 
			
		||||
    def clear(self, target=None):
 | 
			
		||||
        if target is None:
 | 
			
		||||
            self._entities.clear()
 | 
			
		||||
        else:
 | 
			
		||||
            del self[target]
 | 
			
		||||
| 
						 | 
				
			
			@ -6,11 +6,8 @@ from base64 import b64encode, b64decode
 | 
			
		|||
from os.path import isfile as file_exists
 | 
			
		||||
from threading import Lock
 | 
			
		||||
 | 
			
		||||
from .. import helpers, utils
 | 
			
		||||
from ..tl.types import (
 | 
			
		||||
    InputPeerUser, InputPeerChat, InputPeerChannel,
 | 
			
		||||
    PeerUser, PeerChat, PeerChannel
 | 
			
		||||
)
 | 
			
		||||
from .entity_database import EntityDatabase
 | 
			
		||||
from .. import helpers
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Session:
 | 
			
		||||
| 
						 | 
				
			
			@ -70,8 +67,7 @@ class Session:
 | 
			
		|||
        self.auth_key = None
 | 
			
		||||
        self.layer = 0
 | 
			
		||||
        self.salt = 0  # Unsigned long
 | 
			
		||||
        self._input_entities = {}  # {marked_id: hash}
 | 
			
		||||
        self._entities_lock = Lock()
 | 
			
		||||
        self.entities = EntityDatabase()  # Known and cached entities
 | 
			
		||||
 | 
			
		||||
    def save(self):
 | 
			
		||||
        """Saves the current session object as session_user_id.session"""
 | 
			
		||||
| 
						 | 
				
			
			@ -90,7 +86,7 @@ class Session:
 | 
			
		|||
                        if self.auth_key else None
 | 
			
		||||
                }
 | 
			
		||||
                if self.save_entities:
 | 
			
		||||
                    out_dict['entities'] = list(self._input_entities.items())
 | 
			
		||||
                    out_dict['entities'] = self.entities.get_input_list()
 | 
			
		||||
 | 
			
		||||
                json.dump(out_dict, file)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -139,8 +135,7 @@ class Session:
 | 
			
		|||
                        key = b64decode(data['auth_key_data'])
 | 
			
		||||
                        result.auth_key = AuthKey(data=key)
 | 
			
		||||
 | 
			
		||||
                    for e_mid, e_hash in data.get('entities', []):
 | 
			
		||||
                        result._input_entities[e_mid] = e_hash
 | 
			
		||||
                    result.entities = EntityDatabase(data.get('entities', []))
 | 
			
		||||
 | 
			
		||||
            except (json.decoder.JSONDecodeError, UnicodeDecodeError):
 | 
			
		||||
                pass
 | 
			
		||||
| 
						 | 
				
			
			@ -186,58 +181,8 @@ class Session:
 | 
			
		|||
        self.time_offset = correct - now
 | 
			
		||||
 | 
			
		||||
    def process_entities(self, tlobject):
 | 
			
		||||
        """Processes all the found entities on the given TLObject,
 | 
			
		||||
           unless .save_entities is False, and saves the session file.
 | 
			
		||||
        """
 | 
			
		||||
        if not self.save_entities:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # Save all input entities we know of
 | 
			
		||||
        entities = []
 | 
			
		||||
        if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
 | 
			
		||||
            entities.extend(tlobject.chats)
 | 
			
		||||
        if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
 | 
			
		||||
            entities.extend(tlobject.users)
 | 
			
		||||
 | 
			
		||||
        if self.add_entities(entities):
 | 
			
		||||
            self.save()  # Save if any new entities got added
 | 
			
		||||
 | 
			
		||||
    def add_entities(self, entities):
 | 
			
		||||
        """Adds new input entities to the local database unconditionally.
 | 
			
		||||
           Unknown types will be ignored.
 | 
			
		||||
        """
 | 
			
		||||
        if not entities:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        new = {}
 | 
			
		||||
        for e in entities:
 | 
			
		||||
        try:
 | 
			
		||||
                p = utils.get_input_peer(e)
 | 
			
		||||
                new[utils.get_peer_id(p, add_mark=True)] = \
 | 
			
		||||
                    getattr(p, 'access_hash', 0)  # chats won't have hash
 | 
			
		||||
            except ValueError:
 | 
			
		||||
            if self.entities.process(tlobject):
 | 
			
		||||
                self.save()  # Save if any new entities got added
 | 
			
		||||
        except:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        with self._entities_lock:
 | 
			
		||||
            before = len(self._input_entities)
 | 
			
		||||
            self._input_entities.update(new)
 | 
			
		||||
            return len(self._input_entities) != before
 | 
			
		||||
 | 
			
		||||
    def get_input_entity(self, peer):
 | 
			
		||||
        """Gets an input entity known its Peer or a marked ID,
 | 
			
		||||
           or raises KeyError if not found/invalid.
 | 
			
		||||
        """
 | 
			
		||||
        if not isinstance(peer, int):
 | 
			
		||||
            peer = utils.get_peer_id(peer, add_mark=True)
 | 
			
		||||
 | 
			
		||||
        entity_hash = self._input_entities[peer]
 | 
			
		||||
        entity_id, peer_class = utils.resolve_id(peer)
 | 
			
		||||
 | 
			
		||||
        if peer_class == PeerUser:
 | 
			
		||||
            return InputPeerUser(entity_id, entity_hash)
 | 
			
		||||
        if peer_class == PeerChat:
 | 
			
		||||
            return InputPeerChat(entity_id)
 | 
			
		||||
        if peer_class == PeerChannel:
 | 
			
		||||
            return InputPeerChannel(entity_id, entity_hash)
 | 
			
		||||
 | 
			
		||||
        raise KeyError()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -72,7 +72,7 @@ def _raise_cast_fail(entity, target):
 | 
			
		|||
                     .format(type(entity).__name__, target))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_input_peer(entity):
 | 
			
		||||
def get_input_peer(entity, allow_self=True):
 | 
			
		||||
    """Gets the input peer for the given "entity" (user, chat or channel).
 | 
			
		||||
       A ValueError is raised if the given entity isn't a supported type."""
 | 
			
		||||
    if not isinstance(entity, TLObject):
 | 
			
		||||
| 
						 | 
				
			
			@ -82,7 +82,7 @@ def get_input_peer(entity):
 | 
			
		|||
        return entity
 | 
			
		||||
 | 
			
		||||
    if isinstance(entity, User):
 | 
			
		||||
        if entity.is_self:
 | 
			
		||||
        if entity.is_self and allow_self:
 | 
			
		||||
            return InputPeerSelf()
 | 
			
		||||
        else:
 | 
			
		||||
            return InputPeerUser(entity.id, entity.access_hash)
 | 
			
		||||
| 
						 | 
				
			
			@ -304,6 +304,16 @@ def get_peer_id(peer, add_mark=False):
 | 
			
		|||
    """Finds the ID of the given peer, and optionally converts it to
 | 
			
		||||
       the "bot api" format if 'add_mark' is set to True.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(peer, TLObject):
 | 
			
		||||
        if isinstance(peer, int):
 | 
			
		||||
            return peer
 | 
			
		||||
        else:
 | 
			
		||||
            _raise_cast_fail(peer, 'int')
 | 
			
		||||
 | 
			
		||||
    if type(peer).SUBCLASS_OF_ID not in {0x2d45687, 0xc91c90b6}:
 | 
			
		||||
        # Not a Peer or an InputPeer, so first get its Input version
 | 
			
		||||
        peer = get_input_peer(peer, allow_self=False)
 | 
			
		||||
 | 
			
		||||
    if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser):
 | 
			
		||||
        return peer.user_id
 | 
			
		||||
    else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user