Merge pull request #296 from LonamiWebs/entity-db

Use a custom database for entities
This commit is contained in:
Lonami 2017-10-05 13:46:16 +02:00 committed by GitHub
commit 427a6aabaa
4 changed files with 259 additions and 136 deletions

View File

@ -1,7 +1,5 @@
import os import os
import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import lru_cache
from mimetypes import guess_type from mimetypes import guess_type
try: try:
@ -17,6 +15,7 @@ from .errors import (
) )
from .network import ConnectionMode from .network import ConnectionMode
from .tl import TLObject from .tl import TLObject
from .tl.entity_database import EntityDatabase
from .tl.functions.account import ( from .tl.functions.account import (
GetPasswordRequest GetPasswordRequest
) )
@ -127,7 +126,7 @@ class TelegramClient(TelegramBareClient):
def send_code_request(self, phone): def send_code_request(self, phone):
"""Sends a code request to the specified phone number""" """Sends a code request to the specified phone number"""
phone = self._parse_phone(phone) phone = EntityDatabase.parse_phone(phone) or self._phone
result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone self._phone = phone
self._phone_code_hash = result.phone_code_hash self._phone_code_hash = result.phone_code_hash
@ -160,7 +159,7 @@ class TelegramClient(TelegramBareClient):
if phone and not code: if phone and not code:
return self.send_code_request(phone) return self.send_code_request(phone)
elif code: elif code:
phone = self._parse_phone(phone) phone = EntityDatabase.parse_phone(phone) or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash phone_code_hash = phone_code_hash or self._phone_code_hash
if not phone: if not phone:
raise ValueError( raise ValueError(
@ -849,7 +848,6 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier # region Small utilities to make users' life easier
@lru_cache()
def get_entity(self, entity): def get_entity(self, entity):
"""Turns an entity into a valid Telegram user or chat. """Turns an entity into a valid Telegram user or chat.
If "entity" is a string which can be converted to an integer, If "entity" is a string which can be converted to an integer,
@ -866,71 +864,52 @@ class TelegramClient(TelegramBareClient):
If the entity is neither, and it's not a TLObject, an If the entity is neither, and it's not a TLObject, an
error will be raised. error will be raised.
""" """
# TODO Maybe cache both the contacts and the entities. try:
# If an user cannot be found, force a cache update through return self.session.entities[entity]
# a public method (since users may change their username) except KeyError:
input_entity = None pass
if isinstance(entity, TLObject):
# crc32(b'InputPeer') and crc32(b'Peer')
if type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687):
input_entity = self.get_input_entity(entity)
else:
# TODO Don't assume it's a valid entity
return entity
elif isinstance(entity, int): if isinstance(entity, int) or (
input_entity = self.get_input_entity(entity) isinstance(entity, TLObject) and
# crc32(b'InputPeer') and crc32(b'Peer')
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = self.get_input_entity(entity)
if isinstance(ie, InputPeerUser):
self.session.process_entities(GetUsersRequest([ie]))
elif isinstance(ie, InputPeerChat):
self.session.process_entities(GetChatsRequest([ie.chat_id]))
elif isinstance(ie, InputPeerChannel):
self.session.process_entities(GetChannelsRequest([ie]))
if input_entity: # The result of Get*Request has been processed and the entity
if isinstance(input_entity, InputPeerUser): # cached if it was found.
return self(GetUsersRequest([input_entity]))[0] return self.session.entities[ie]
elif isinstance(input_entity, InputPeerChat):
return self(GetChatsRequest([input_entity.chat_id])).chats[0]
elif isinstance(input_entity, InputPeerChannel):
return self(GetChannelsRequest([input_entity])).chats[0]
if isinstance(entity, str): if isinstance(entity, str):
stripped_phone = self._parse_phone(entity, ignore_saved=True) return self._get_entity_from_string(entity)
if stripped_phone.isdigit():
contacts = self(GetContactsRequest(0))
try:
return next(
u for u in contacts.users
if u.phone and u.phone.endswith(stripped_phone)
)
except StopIteration:
raise ValueError(
'Could not find user with phone {}, '
'add them to your contacts first'.format(entity)
)
else:
username = entity.strip('@').lower()
resolved = self(ResolveUsernameRequest(username))
for c in resolved.chats:
if getattr(c, 'username', '').lower() == username:
return c
for u in resolved.users:
if getattr(u, 'username', '').lower() == username:
return u
raise ValueError(
'Could not find user with username {}'.format(entity)
)
raise ValueError( raise ValueError(
'Cannot turn "{}" into any entity (user or chat)'.format(entity) 'Cannot turn "{}" into any entity (user or chat)'.format(entity)
) )
def _parse_phone(self, phone, ignore_saved=False): def _get_entity_from_string(self, string):
if isinstance(phone, int): """Gets an entity from the given string, which may be a phone or
phone = str(phone) an username, and processes all the found entities on the session.
elif phone: """
phone = re.sub(r'[+()\s-]', '', phone) phone = EntityDatabase.parse_phone(string)
if phone:
if ignore_saved: entity = phone
return phone self.session.process_entities(self(GetContactsRequest(0)))
else: else:
return phone or self._phone entity = string.strip('@').lower()
self.session.process_entities(self(ResolveUsernameRequest(entity)))
try:
return self.session.entities[entity]
except KeyError:
raise ValueError(
'Could not find user with username {}'.format(entity)
)
def get_input_entity(self, peer): def get_input_entity(self, peer):
"""Gets the input entity given its PeerUser, PeerChat, PeerChannel. """Gets the input entity given its PeerUser, PeerChat, PeerChannel.
@ -940,11 +919,16 @@ class TelegramClient(TelegramBareClient):
If this Peer hasn't been seen before by the library, all dialogs If this Peer hasn't been seen before by the library, all dialogs
will loaded, and their entities saved to the session file. will loaded, and their entities saved to the session file.
If even after If even after it's not found, a ValueError is raised.
""" """
try:
# First try to get the entity from cache, otherwise figure it out
self.session.entities.get_input_entity(peer)
except KeyError:
pass
if isinstance(peer, str): if isinstance(peer, str):
# Let .get_entity resolve the username or phone (full entity) return utils.get_input_peer(self._get_entity_from_string(peer))
peer = self.get_entity(peer)
is_peer = False is_peer = False
if isinstance(peer, int): if isinstance(peer, int):
@ -964,16 +948,12 @@ class TelegramClient(TelegramBareClient):
'Cannot turn "{}" into an input entity.'.format(peer) 'Cannot turn "{}" into an input entity.'.format(peer)
) )
try:
return self.session.get_input_entity(peer)
except KeyError:
pass
if self.session.save_entities: if self.session.save_entities:
# Not found, look in the dialogs (this will save the users) # Not found, look in the dialogs (this will save the users)
self.get_dialogs(limit=None) self.get_dialogs(limit=None)
try: try:
return self.session.get_input_entity(peer) self.session.entities.get_input_entity(peer)
except KeyError: except KeyError:
pass pass

View File

@ -0,0 +1,188 @@
from threading import Lock
import re
from .. import utils
from ..tl import TLObject
from ..tl.types import User, Chat, Channel
class EntityDatabase:
def __init__(self, input_list=None, enabled=True, enabled_full=True):
"""Creates a new entity database with an initial load of "Input"
entities, if any.
If 'enabled', input entities will be saved. The whole entity
will be saved if both 'enabled' and 'enabled_full' are True.
"""
self.enabled = enabled
self.enabled_full = enabled_full
self._lock = Lock()
self._entities = {} # marked_id: user|chat|channel
if input_list:
self._input_entities = {k: v for k, v in input_list}
else:
self._input_entities = {} # marked_id: hash
# TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id
self._phone_id = {} # phone: marked_id
def process(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.enabled:
return False
# Save all input entities we know of
if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'):
# This may be a list of users already for instance
return self.expand(tlobject)
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
return self.expand(entities)
def expand(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities or not self.enabled:
return False
new = [] # Array of entities (User, Chat, or Channel)
new_input = {} # Dictionary of {entity_marked_id: access_hash}
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
new_input[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
if self.enabled_full:
if isinstance(e, User) \
or isinstance(e, Chat) \
or isinstance(e, Channel):
new.append(e)
except ValueError:
pass
with self._lock:
before = len(self._input_entities)
self._input_entities.update(new_input)
for e in new:
self._add_full_entity(e)
return len(self._input_entities) != before
def _add_full_entity(self, entity):
"""Adds a "full" entity (User, Chat or Channel, not "Input*"),
despite the value of self.enabled and self.enabled_full.
Not to be confused with UserFull, ChatFull, or ChannelFull,
"full" means simply not "Input*".
"""
marked_id = utils.get_peer_id(
utils.get_input_peer(entity, allow_self=False), add_mark=True
)
try:
old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity.__dict__) # Keep old references
# Update must delete old username and phone
username = getattr(old_entity, 'username', None)
if username:
del self._username_id[username.lower()]
phone = getattr(old_entity, 'phone', None)
if phone:
del self._phone_id[phone]
except KeyError:
# Add new entity
self._entities[marked_id] = entity
# Always update username or phone if any
username = getattr(entity, 'username', None)
if username:
self._username_id[username.lower()] = marked_id
phone = getattr(entity, 'phone', None)
if phone:
self._username_id[phone] = marked_id
def __getitem__(self, key):
"""Accepts a digit only string as phone number,
otherwise it's treated as an username.
If an integer is given, it's treated as the ID of the desired User.
The ID given won't try to be guessed as the ID of a chat or channel,
as there may be an user with that ID, and it would be unreliable.
If a Peer is given (PeerUser, PeerChat, PeerChannel),
its specific entity is retrieved as User, Chat or Channel.
Note that megagroups are channels with .megagroup = True.
"""
if isinstance(key, str):
phone = EntityDatabase.parse_phone(key)
if phone:
return self._phone_id[phone]
else:
key = key.lstrip('@').lower()
return self._entities[self._username_id[key]]
if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users
if isinstance(key, TLObject):
sc = type(key).SUBCLASS_OF_ID
if sc == 0x2d45687:
# Subclass of "Peer"
return self._entities[utils.get_peer_id(key, add_mark=True)]
elif sc in {0x2da17977, 0xc5af5d94, 0x6d44b7db}:
# Subclass of "User", "Chat" or "Channel"
return key
raise KeyError(key)
def __delitem__(self, key):
target = self[key]
del self._entities[key]
if getattr(target, 'username'):
del self._username_id[target.username]
# TODO Allow search by name by tokenizing the input and return a list
@staticmethod
def parse_phone(phone):
"""Parses the given phone, or returns None if it's invalid"""
if isinstance(phone, int):
return str(phone)
else:
phone = re.sub(r'[+()\s-]', '', str(phone))
if phone.isdigit():
return phone
def get_input_entity(self, peer):
try:
return self._input_entities[utils.get_peer_id(peer, add_mark=True)]
except ValueError as e:
raise KeyError(peer) from e
def get_input_list(self):
return list(self._input_entities.items())
def clear(self, target=None):
if target is None:
self._entities.clear()
else:
del self[target]

View File

@ -6,11 +6,8 @@ from base64 import b64encode, b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock from threading import Lock
from .. import helpers, utils from .entity_database import EntityDatabase
from ..tl.types import ( from .. import helpers
InputPeerUser, InputPeerChat, InputPeerChannel,
PeerUser, PeerChat, PeerChannel
)
class Session: class Session:
@ -70,8 +67,7 @@ class Session:
self.auth_key = None self.auth_key = None
self.layer = 0 self.layer = 0
self.salt = 0 # Unsigned long self.salt = 0 # Unsigned long
self._input_entities = {} # {marked_id: hash} self.entities = EntityDatabase() # Known and cached entities
self._entities_lock = Lock()
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
@ -90,7 +86,7 @@ class Session:
if self.auth_key else None if self.auth_key else None
} }
if self.save_entities: if self.save_entities:
out_dict['entities'] = list(self._input_entities.items()) out_dict['entities'] = self.entities.get_input_list()
json.dump(out_dict, file) json.dump(out_dict, file)
@ -139,8 +135,7 @@ class Session:
key = b64decode(data['auth_key_data']) key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key) result.auth_key = AuthKey(data=key)
for e_mid, e_hash in data.get('entities', []): result.entities = EntityDatabase(data.get('entities', []))
result._input_entities[e_mid] = e_hash
except (json.decoder.JSONDecodeError, UnicodeDecodeError): except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass pass
@ -186,58 +181,8 @@ class Session:
self.time_offset = correct - now self.time_offset = correct - now
def process_entities(self, tlobject): def process_entities(self, tlobject):
"""Processes all the found entities on the given TLObject, try:
unless .save_entities is False, and saves the session file. if self.entities.process(tlobject):
""" self.save() # Save if any new entities got added
if not self.save_entities: except:
return pass
# Save all input entities we know of
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
if self.add_entities(entities):
self.save() # Save if any new entities got added
def add_entities(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities:
return False
new = {}
for e in entities:
try:
p = utils.get_input_peer(e)
new[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
except ValueError:
pass
with self._entities_lock:
before = len(self._input_entities)
self._input_entities.update(new)
return len(self._input_entities) != before
def get_input_entity(self, peer):
"""Gets an input entity known its Peer or a marked ID,
or raises KeyError if not found/invalid.
"""
if not isinstance(peer, int):
peer = utils.get_peer_id(peer, add_mark=True)
entity_hash = self._input_entities[peer]
entity_id, peer_class = utils.resolve_id(peer)
if peer_class == PeerUser:
return InputPeerUser(entity_id, entity_hash)
if peer_class == PeerChat:
return InputPeerChat(entity_id)
if peer_class == PeerChannel:
return InputPeerChannel(entity_id, entity_hash)
raise KeyError()

View File

@ -72,7 +72,7 @@ def _raise_cast_fail(entity, target):
.format(type(entity).__name__, target)) .format(type(entity).__name__, target))
def get_input_peer(entity): def get_input_peer(entity, allow_self=True):
"""Gets the input peer for the given "entity" (user, chat or channel). """Gets the input peer for the given "entity" (user, chat or channel).
A ValueError is raised if the given entity isn't a supported type.""" A ValueError is raised if the given entity isn't a supported type."""
if not isinstance(entity, TLObject): if not isinstance(entity, TLObject):
@ -82,7 +82,7 @@ def get_input_peer(entity):
return entity return entity
if isinstance(entity, User): if isinstance(entity, User):
if entity.is_self: if entity.is_self and allow_self:
return InputPeerSelf() return InputPeerSelf()
else: else:
return InputPeerUser(entity.id, entity.access_hash) return InputPeerUser(entity.id, entity.access_hash)
@ -304,6 +304,16 @@ def get_peer_id(peer, add_mark=False):
"""Finds the ID of the given peer, and optionally converts it to """Finds the ID of the given peer, and optionally converts it to
the "bot api" format if 'add_mark' is set to True. the "bot api" format if 'add_mark' is set to True.
""" """
if not isinstance(peer, TLObject):
if isinstance(peer, int):
return peer
else:
_raise_cast_fail(peer, 'int')
if type(peer).SUBCLASS_OF_ID not in {0x2d45687, 0xc91c90b6}:
# Not a Peer or an InputPeer, so first get its Input version
peer = get_input_peer(peer, allow_self=False)
if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser): if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser):
return peer.user_id return peer.user_id
else: else: