Merge pull request #296 from LonamiWebs/entity-db

Use a custom database for entities
This commit is contained in:
Lonami 2017-10-05 13:46:16 +02:00 committed by GitHub
commit 427a6aabaa
4 changed files with 259 additions and 136 deletions

View File

@ -1,7 +1,5 @@
import os
import re
from datetime import datetime, timedelta
from functools import lru_cache
from mimetypes import guess_type
try:
@ -17,6 +15,7 @@ from .errors import (
)
from .network import ConnectionMode
from .tl import TLObject
from .tl.entity_database import EntityDatabase
from .tl.functions.account import (
GetPasswordRequest
)
@ -127,7 +126,7 @@ class TelegramClient(TelegramBareClient):
def send_code_request(self, phone):
"""Sends a code request to the specified phone number"""
phone = self._parse_phone(phone)
phone = EntityDatabase.parse_phone(phone) or self._phone
result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone
self._phone_code_hash = result.phone_code_hash
@ -160,7 +159,7 @@ class TelegramClient(TelegramBareClient):
if phone and not code:
return self.send_code_request(phone)
elif code:
phone = self._parse_phone(phone)
phone = EntityDatabase.parse_phone(phone) or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash
if not phone:
raise ValueError(
@ -849,7 +848,6 @@ class TelegramClient(TelegramBareClient):
# region Small utilities to make users' life easier
@lru_cache()
def get_entity(self, entity):
"""Turns an entity into a valid Telegram user or chat.
If "entity" is a string which can be converted to an integer,
@ -866,71 +864,52 @@ class TelegramClient(TelegramBareClient):
If the entity is neither, and it's not a TLObject, an
error will be raised.
"""
# TODO Maybe cache both the contacts and the entities.
# If an user cannot be found, force a cache update through
# a public method (since users may change their username)
input_entity = None
if isinstance(entity, TLObject):
# crc32(b'InputPeer') and crc32(b'Peer')
if type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687):
input_entity = self.get_input_entity(entity)
else:
# TODO Don't assume it's a valid entity
return entity
try:
return self.session.entities[entity]
except KeyError:
pass
elif isinstance(entity, int):
input_entity = self.get_input_entity(entity)
if isinstance(entity, int) or (
isinstance(entity, TLObject) and
# crc32(b'InputPeer') and crc32(b'Peer')
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = self.get_input_entity(entity)
if isinstance(ie, InputPeerUser):
self.session.process_entities(GetUsersRequest([ie]))
elif isinstance(ie, InputPeerChat):
self.session.process_entities(GetChatsRequest([ie.chat_id]))
elif isinstance(ie, InputPeerChannel):
self.session.process_entities(GetChannelsRequest([ie]))
if input_entity:
if isinstance(input_entity, InputPeerUser):
return self(GetUsersRequest([input_entity]))[0]
elif isinstance(input_entity, InputPeerChat):
return self(GetChatsRequest([input_entity.chat_id])).chats[0]
elif isinstance(input_entity, InputPeerChannel):
return self(GetChannelsRequest([input_entity])).chats[0]
# The result of Get*Request has been processed and the entity
# cached if it was found.
return self.session.entities[ie]
if isinstance(entity, str):
stripped_phone = self._parse_phone(entity, ignore_saved=True)
if stripped_phone.isdigit():
contacts = self(GetContactsRequest(0))
try:
return next(
u for u in contacts.users
if u.phone and u.phone.endswith(stripped_phone)
)
except StopIteration:
raise ValueError(
'Could not find user with phone {}, '
'add them to your contacts first'.format(entity)
)
else:
username = entity.strip('@').lower()
resolved = self(ResolveUsernameRequest(username))
for c in resolved.chats:
if getattr(c, 'username', '').lower() == username:
return c
for u in resolved.users:
if getattr(u, 'username', '').lower() == username:
return u
raise ValueError(
'Could not find user with username {}'.format(entity)
)
return self._get_entity_from_string(entity)
raise ValueError(
'Cannot turn "{}" into any entity (user or chat)'.format(entity)
)
def _parse_phone(self, phone, ignore_saved=False):
if isinstance(phone, int):
phone = str(phone)
elif phone:
phone = re.sub(r'[+()\s-]', '', phone)
if ignore_saved:
return phone
def _get_entity_from_string(self, string):
"""Gets an entity from the given string, which may be a phone or
an username, and processes all the found entities on the session.
"""
phone = EntityDatabase.parse_phone(string)
if phone:
entity = phone
self.session.process_entities(self(GetContactsRequest(0)))
else:
return phone or self._phone
entity = string.strip('@').lower()
self.session.process_entities(self(ResolveUsernameRequest(entity)))
try:
return self.session.entities[entity]
except KeyError:
raise ValueError(
'Could not find user with username {}'.format(entity)
)
def get_input_entity(self, peer):
"""Gets the input entity given its PeerUser, PeerChat, PeerChannel.
@ -940,11 +919,16 @@ class TelegramClient(TelegramBareClient):
If this Peer hasn't been seen before by the library, all dialogs
will loaded, and their entities saved to the session file.
If even after
If even after it's not found, a ValueError is raised.
"""
try:
# First try to get the entity from cache, otherwise figure it out
self.session.entities.get_input_entity(peer)
except KeyError:
pass
if isinstance(peer, str):
# Let .get_entity resolve the username or phone (full entity)
peer = self.get_entity(peer)
return utils.get_input_peer(self._get_entity_from_string(peer))
is_peer = False
if isinstance(peer, int):
@ -964,16 +948,12 @@ class TelegramClient(TelegramBareClient):
'Cannot turn "{}" into an input entity.'.format(peer)
)
try:
return self.session.get_input_entity(peer)
except KeyError:
pass
if self.session.save_entities:
# Not found, look in the dialogs (this will save the users)
self.get_dialogs(limit=None)
try:
return self.session.get_input_entity(peer)
self.session.entities.get_input_entity(peer)
except KeyError:
pass

View File

@ -0,0 +1,188 @@
from threading import Lock
import re
from .. import utils
from ..tl import TLObject
from ..tl.types import User, Chat, Channel
class EntityDatabase:
def __init__(self, input_list=None, enabled=True, enabled_full=True):
"""Creates a new entity database with an initial load of "Input"
entities, if any.
If 'enabled', input entities will be saved. The whole entity
will be saved if both 'enabled' and 'enabled_full' are True.
"""
self.enabled = enabled
self.enabled_full = enabled_full
self._lock = Lock()
self._entities = {} # marked_id: user|chat|channel
if input_list:
self._input_entities = {k: v for k, v in input_list}
else:
self._input_entities = {} # marked_id: hash
# TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id
self._phone_id = {} # phone: marked_id
def process(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.enabled:
return False
# Save all input entities we know of
if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'):
# This may be a list of users already for instance
return self.expand(tlobject)
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
return self.expand(entities)
def expand(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities or not self.enabled:
return False
new = [] # Array of entities (User, Chat, or Channel)
new_input = {} # Dictionary of {entity_marked_id: access_hash}
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
new_input[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
if self.enabled_full:
if isinstance(e, User) \
or isinstance(e, Chat) \
or isinstance(e, Channel):
new.append(e)
except ValueError:
pass
with self._lock:
before = len(self._input_entities)
self._input_entities.update(new_input)
for e in new:
self._add_full_entity(e)
return len(self._input_entities) != before
def _add_full_entity(self, entity):
"""Adds a "full" entity (User, Chat or Channel, not "Input*"),
despite the value of self.enabled and self.enabled_full.
Not to be confused with UserFull, ChatFull, or ChannelFull,
"full" means simply not "Input*".
"""
marked_id = utils.get_peer_id(
utils.get_input_peer(entity, allow_self=False), add_mark=True
)
try:
old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity.__dict__) # Keep old references
# Update must delete old username and phone
username = getattr(old_entity, 'username', None)
if username:
del self._username_id[username.lower()]
phone = getattr(old_entity, 'phone', None)
if phone:
del self._phone_id[phone]
except KeyError:
# Add new entity
self._entities[marked_id] = entity
# Always update username or phone if any
username = getattr(entity, 'username', None)
if username:
self._username_id[username.lower()] = marked_id
phone = getattr(entity, 'phone', None)
if phone:
self._username_id[phone] = marked_id
def __getitem__(self, key):
"""Accepts a digit only string as phone number,
otherwise it's treated as an username.
If an integer is given, it's treated as the ID of the desired User.
The ID given won't try to be guessed as the ID of a chat or channel,
as there may be an user with that ID, and it would be unreliable.
If a Peer is given (PeerUser, PeerChat, PeerChannel),
its specific entity is retrieved as User, Chat or Channel.
Note that megagroups are channels with .megagroup = True.
"""
if isinstance(key, str):
phone = EntityDatabase.parse_phone(key)
if phone:
return self._phone_id[phone]
else:
key = key.lstrip('@').lower()
return self._entities[self._username_id[key]]
if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users
if isinstance(key, TLObject):
sc = type(key).SUBCLASS_OF_ID
if sc == 0x2d45687:
# Subclass of "Peer"
return self._entities[utils.get_peer_id(key, add_mark=True)]
elif sc in {0x2da17977, 0xc5af5d94, 0x6d44b7db}:
# Subclass of "User", "Chat" or "Channel"
return key
raise KeyError(key)
def __delitem__(self, key):
target = self[key]
del self._entities[key]
if getattr(target, 'username'):
del self._username_id[target.username]
# TODO Allow search by name by tokenizing the input and return a list
@staticmethod
def parse_phone(phone):
"""Parses the given phone, or returns None if it's invalid"""
if isinstance(phone, int):
return str(phone)
else:
phone = re.sub(r'[+()\s-]', '', str(phone))
if phone.isdigit():
return phone
def get_input_entity(self, peer):
try:
return self._input_entities[utils.get_peer_id(peer, add_mark=True)]
except ValueError as e:
raise KeyError(peer) from e
def get_input_list(self):
return list(self._input_entities.items())
def clear(self, target=None):
if target is None:
self._entities.clear()
else:
del self[target]

View File

@ -6,11 +6,8 @@ from base64 import b64encode, b64decode
from os.path import isfile as file_exists
from threading import Lock
from .. import helpers, utils
from ..tl.types import (
InputPeerUser, InputPeerChat, InputPeerChannel,
PeerUser, PeerChat, PeerChannel
)
from .entity_database import EntityDatabase
from .. import helpers
class Session:
@ -70,8 +67,7 @@ class Session:
self.auth_key = None
self.layer = 0
self.salt = 0 # Unsigned long
self._input_entities = {} # {marked_id: hash}
self._entities_lock = Lock()
self.entities = EntityDatabase() # Known and cached entities
def save(self):
"""Saves the current session object as session_user_id.session"""
@ -90,7 +86,7 @@ class Session:
if self.auth_key else None
}
if self.save_entities:
out_dict['entities'] = list(self._input_entities.items())
out_dict['entities'] = self.entities.get_input_list()
json.dump(out_dict, file)
@ -139,8 +135,7 @@ class Session:
key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key)
for e_mid, e_hash in data.get('entities', []):
result._input_entities[e_mid] = e_hash
result.entities = EntityDatabase(data.get('entities', []))
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass
@ -186,58 +181,8 @@ class Session:
self.time_offset = correct - now
def process_entities(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .save_entities is False, and saves the session file.
"""
if not self.save_entities:
return
# Save all input entities we know of
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
if self.add_entities(entities):
self.save() # Save if any new entities got added
def add_entities(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities:
return False
new = {}
for e in entities:
try:
p = utils.get_input_peer(e)
new[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
except ValueError:
pass
with self._entities_lock:
before = len(self._input_entities)
self._input_entities.update(new)
return len(self._input_entities) != before
def get_input_entity(self, peer):
"""Gets an input entity known its Peer or a marked ID,
or raises KeyError if not found/invalid.
"""
if not isinstance(peer, int):
peer = utils.get_peer_id(peer, add_mark=True)
entity_hash = self._input_entities[peer]
entity_id, peer_class = utils.resolve_id(peer)
if peer_class == PeerUser:
return InputPeerUser(entity_id, entity_hash)
if peer_class == PeerChat:
return InputPeerChat(entity_id)
if peer_class == PeerChannel:
return InputPeerChannel(entity_id, entity_hash)
raise KeyError()
try:
if self.entities.process(tlobject):
self.save() # Save if any new entities got added
except:
pass

View File

@ -72,7 +72,7 @@ def _raise_cast_fail(entity, target):
.format(type(entity).__name__, target))
def get_input_peer(entity):
def get_input_peer(entity, allow_self=True):
"""Gets the input peer for the given "entity" (user, chat or channel).
A ValueError is raised if the given entity isn't a supported type."""
if not isinstance(entity, TLObject):
@ -82,7 +82,7 @@ def get_input_peer(entity):
return entity
if isinstance(entity, User):
if entity.is_self:
if entity.is_self and allow_self:
return InputPeerSelf()
else:
return InputPeerUser(entity.id, entity.access_hash)
@ -304,6 +304,16 @@ def get_peer_id(peer, add_mark=False):
"""Finds the ID of the given peer, and optionally converts it to
the "bot api" format if 'add_mark' is set to True.
"""
if not isinstance(peer, TLObject):
if isinstance(peer, int):
return peer
else:
_raise_cast_fail(peer, 'int')
if type(peer).SUBCLASS_OF_ID not in {0x2d45687, 0xc91c90b6}:
# Not a Peer or an InputPeer, so first get its Input version
peer = get_input_peer(peer, allow_self=False)
if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser):
return peer.user_id
else: