Use EntityDatabase in the Session class

This commit is contained in:
Lonami Exo 2017-10-04 21:02:45 +02:00
parent 5be9df0eec
commit a0fc5ed54e
3 changed files with 146 additions and 139 deletions

View File

@ -1,75 +0,0 @@
from . import utils
from .tl import TLObject
class EntityDatabase:
def __init__(self, enabled=True):
self.enabled = enabled
self._entities = {} # marked_id: user|chat|channel
# TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id
def add(self, entity):
if not self.enabled:
return
# Adds or updates the given entity
marked_id = utils.get_peer_id(entity, add_mark=True)
try:
old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity) # Keep old references
# Update must delete old username
username = getattr(old_entity, 'username', None)
if username:
del self._username_id[username.lower()]
except KeyError:
# Add new entity
self._entities[marked_id] = entity
# Always update username if any
username = getattr(entity, 'username', None)
if username:
self._username_id[username.lower()] = marked_id
def __getitem__(self, key):
"""Accepts a digit only string as phone number,
otherwise it's treated as an username.
If an integer is given, it's treated as the ID of the desired User.
The ID given won't try to be guessed as the ID of a chat or channel,
as there may be an user with that ID, and it would be unreliable.
If a Peer is given (PeerUser, PeerChat, PeerChannel),
its specific entity is retrieved as User, Chat or Channel.
Note that megagroups are channels with .megagroup = True.
"""
if isinstance(key, str):
# TODO Parse phone properly, currently only usernames
key = key.lstrip('@').lower()
# TODO Use the client to return from username if not found
return self._entities[self._username_id[key]]
if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users
if isinstance(key, TLObject) and type(key).SUBCLASS_OF_ID == 0x2d45687:
return self._entities[utils.get_peer_id(key, add_mark=True)]
raise KeyError(key)
def __delitem__(self, key):
target = self[key]
del self._entities[key]
if getattr(target, 'username'):
del self._username_id[target.username]
# TODO Allow search by name by tokenizing the input and return a list
def clear(self, target=None):
if target is None:
self._entities.clear()
else:
del self[target]

View File

@ -0,0 +1,140 @@
from threading import Lock
from .. import utils
from ..tl import TLObject
from ..tl.types import User, Chat, Channel
class EntityDatabase:
def __init__(self, input_list=None, enabled=True):
self.enabled = enabled
self._lock = Lock()
self._entities = {} # marked_id: user|chat|channel
if input_list:
self._input_entities = {k: v for k, v in input_list}
else:
self._input_entities = {} # marked_id: hash
# TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id
def process(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.enabled:
return False
# Save all input entities we know of
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
return self.expand(entities)
def expand(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities or not self.enabled:
return False
new = [] # Array of entities (User, Chat, or Channel)
new_input = {} # Dictionary of {entity_marked_id: access_hash}
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e)
new_input[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
if isinstance(e, User) \
or isinstance(e, Chat) \
or isinstance(e, Channel):
new.append(e)
except ValueError:
pass
with self._lock:
before = len(self._input_entities)
self._input_entities.update(new_input)
for e in new:
self._add_full_entity(e)
return len(self._input_entities) != before
def _add_full_entity(self, entity):
"""Adds a "full" entity (User, Chat or Channel, not "Input*").
Not to be confused with UserFull, ChatFull, or ChannelFull,
"full" means simply not "Input*".
"""
marked_id = utils.get_peer_id(
utils.get_input_peer(entity), add_mark=True
)
try:
old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity.__dict__) # Keep old references
# Update must delete old username
username = getattr(old_entity, 'username', None)
if username:
del self._username_id[username.lower()]
except KeyError:
# Add new entity
self._entities[marked_id] = entity
# Always update username if any
username = getattr(entity, 'username', None)
if username:
self._username_id[username.lower()] = marked_id
def __getitem__(self, key):
"""Accepts a digit only string as phone number,
otherwise it's treated as an username.
If an integer is given, it's treated as the ID of the desired User.
The ID given won't try to be guessed as the ID of a chat or channel,
as there may be an user with that ID, and it would be unreliable.
If a Peer is given (PeerUser, PeerChat, PeerChannel),
its specific entity is retrieved as User, Chat or Channel.
Note that megagroups are channels with .megagroup = True.
"""
if isinstance(key, str):
# TODO Parse phone properly, currently only usernames
key = key.lstrip('@').lower()
# TODO Use the client to return from username if not found
return self._entities[self._username_id[key]]
if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users
if isinstance(key, TLObject) and type(key).SUBCLASS_OF_ID == 0x2d45687:
return self._entities[utils.get_peer_id(key, add_mark=True)]
raise KeyError(key)
def __delitem__(self, key):
target = self[key]
del self._entities[key]
if getattr(target, 'username'):
del self._username_id[target.username]
# TODO Allow search by name by tokenizing the input and return a list
def get_input_list(self):
return list(self._input_entities.items())
def clear(self, target=None):
if target is None:
self._entities.clear()
else:
del self[target]

View File

@ -6,11 +6,8 @@ from base64 import b64encode, b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock from threading import Lock
from .. import helpers, utils from .entity_database import EntityDatabase
from ..tl.types import ( from .. import helpers
InputPeerUser, InputPeerChat, InputPeerChannel,
PeerUser, PeerChat, PeerChannel
)
class Session: class Session:
@ -70,8 +67,7 @@ class Session:
self.auth_key = None self.auth_key = None
self.layer = 0 self.layer = 0
self.salt = 0 # Unsigned long self.salt = 0 # Unsigned long
self._input_entities = {} # {marked_id: hash} self.entities = EntityDatabase() # Known and cached entities
self._entities_lock = Lock()
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
@ -90,7 +86,7 @@ class Session:
if self.auth_key else None if self.auth_key else None
} }
if self.save_entities: if self.save_entities:
out_dict['entities'] = list(self._input_entities.items()) out_dict['entities'] = self.entities.get_input_list()
json.dump(out_dict, file) json.dump(out_dict, file)
@ -139,8 +135,7 @@ class Session:
key = b64decode(data['auth_key_data']) key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key) result.auth_key = AuthKey(data=key)
for e_mid, e_hash in data.get('entities', []): result.entities = EntityDatabase(data.get('entities', []))
result._input_entities[e_mid] = e_hash
except (json.decoder.JSONDecodeError, UnicodeDecodeError): except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass pass
@ -186,58 +181,5 @@ class Session:
self.time_offset = correct - now self.time_offset = correct - now
def process_entities(self, tlobject): def process_entities(self, tlobject):
"""Processes all the found entities on the given TLObject, if self.entities.process(tlobject):
unless .save_entities is False, and saves the session file.
"""
if not self.save_entities:
return
# Save all input entities we know of
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
if self.add_entities(entities):
self.save() # Save if any new entities got added self.save() # Save if any new entities got added
def add_entities(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities:
return False
new = {}
for e in entities:
try:
p = utils.get_input_peer(e)
new[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
except ValueError:
pass
with self._entities_lock:
before = len(self._input_entities)
self._input_entities.update(new)
return len(self._input_entities) != before
def get_input_entity(self, peer):
"""Gets an input entity known its Peer or a marked ID,
or raises KeyError if not found/invalid.
"""
if not isinstance(peer, int):
peer = utils.get_peer_id(peer, add_mark=True)
entity_hash = self._input_entities[peer]
entity_id, peer_class = utils.resolve_id(peer)
if peer_class == PeerUser:
return InputPeerUser(entity_id, entity_hash)
if peer_class == PeerChat:
return InputPeerChat(entity_id)
if peer_class == PeerChannel:
return InputPeerChannel(entity_id, entity_hash)
raise KeyError()