mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-01-24 08:14:14 +03:00
Use EntityDatabase in the Session class
This commit is contained in:
parent
5be9df0eec
commit
a0fc5ed54e
|
@ -1,75 +0,0 @@
|
|||
from . import utils
|
||||
from .tl import TLObject
|
||||
|
||||
|
||||
class EntityDatabase:
|
||||
def __init__(self, enabled=True):
|
||||
self.enabled = enabled
|
||||
|
||||
self._entities = {} # marked_id: user|chat|channel
|
||||
|
||||
# TODO Allow disabling some extra mappings
|
||||
self._username_id = {} # username: marked_id
|
||||
|
||||
def add(self, entity):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Adds or updates the given entity
|
||||
marked_id = utils.get_peer_id(entity, add_mark=True)
|
||||
try:
|
||||
old_entity = self._entities[marked_id]
|
||||
old_entity.__dict__.update(entity) # Keep old references
|
||||
|
||||
# Update must delete old username
|
||||
username = getattr(old_entity, 'username', None)
|
||||
if username:
|
||||
del self._username_id[username.lower()]
|
||||
except KeyError:
|
||||
# Add new entity
|
||||
self._entities[marked_id] = entity
|
||||
|
||||
# Always update username if any
|
||||
username = getattr(entity, 'username', None)
|
||||
if username:
|
||||
self._username_id[username.lower()] = marked_id
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Accepts a digit only string as phone number,
|
||||
otherwise it's treated as an username.
|
||||
|
||||
If an integer is given, it's treated as the ID of the desired User.
|
||||
The ID given won't try to be guessed as the ID of a chat or channel,
|
||||
as there may be an user with that ID, and it would be unreliable.
|
||||
|
||||
If a Peer is given (PeerUser, PeerChat, PeerChannel),
|
||||
its specific entity is retrieved as User, Chat or Channel.
|
||||
Note that megagroups are channels with .megagroup = True.
|
||||
"""
|
||||
if isinstance(key, str):
|
||||
# TODO Parse phone properly, currently only usernames
|
||||
key = key.lstrip('@').lower()
|
||||
# TODO Use the client to return from username if not found
|
||||
return self._entities[self._username_id[key]]
|
||||
|
||||
if isinstance(key, int):
|
||||
return self._entities[key] # normal IDs are assumed users
|
||||
|
||||
if isinstance(key, TLObject) and type(key).SUBCLASS_OF_ID == 0x2d45687:
|
||||
return self._entities[utils.get_peer_id(key, add_mark=True)]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
target = self[key]
|
||||
del self._entities[key]
|
||||
if getattr(target, 'username'):
|
||||
del self._username_id[target.username]
|
||||
|
||||
# TODO Allow search by name by tokenizing the input and return a list
|
||||
|
||||
def clear(self, target=None):
|
||||
if target is None:
|
||||
self._entities.clear()
|
||||
else:
|
||||
del self[target]
|
140
telethon/tl/entity_database.py
Normal file
140
telethon/tl/entity_database.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
from threading import Lock
|
||||
|
||||
from .. import utils
|
||||
from ..tl import TLObject
|
||||
from ..tl.types import User, Chat, Channel
|
||||
|
||||
|
||||
class EntityDatabase:
|
||||
def __init__(self, input_list=None, enabled=True):
|
||||
self.enabled = enabled
|
||||
|
||||
self._lock = Lock()
|
||||
self._entities = {} # marked_id: user|chat|channel
|
||||
|
||||
if input_list:
|
||||
self._input_entities = {k: v for k, v in input_list}
|
||||
else:
|
||||
self._input_entities = {} # marked_id: hash
|
||||
|
||||
# TODO Allow disabling some extra mappings
|
||||
self._username_id = {} # username: marked_id
|
||||
|
||||
def process(self, tlobject):
|
||||
"""Processes all the found entities on the given TLObject,
|
||||
unless .enabled is False.
|
||||
|
||||
Returns True if new input entities were added.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# Save all input entities we know of
|
||||
entities = []
|
||||
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
|
||||
entities.extend(tlobject.chats)
|
||||
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
|
||||
entities.extend(tlobject.users)
|
||||
|
||||
return self.expand(entities)
|
||||
|
||||
def expand(self, entities):
|
||||
"""Adds new input entities to the local database unconditionally.
|
||||
Unknown types will be ignored.
|
||||
"""
|
||||
if not entities or not self.enabled:
|
||||
return False
|
||||
|
||||
new = [] # Array of entities (User, Chat, or Channel)
|
||||
new_input = {} # Dictionary of {entity_marked_id: access_hash}
|
||||
for e in entities:
|
||||
if not isinstance(e, TLObject):
|
||||
continue
|
||||
|
||||
try:
|
||||
p = utils.get_input_peer(e)
|
||||
new_input[utils.get_peer_id(p, add_mark=True)] = \
|
||||
getattr(p, 'access_hash', 0) # chats won't have hash
|
||||
|
||||
if isinstance(e, User) \
|
||||
or isinstance(e, Chat) \
|
||||
or isinstance(e, Channel):
|
||||
new.append(e)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
with self._lock:
|
||||
before = len(self._input_entities)
|
||||
self._input_entities.update(new_input)
|
||||
for e in new:
|
||||
self._add_full_entity(e)
|
||||
return len(self._input_entities) != before
|
||||
|
||||
def _add_full_entity(self, entity):
|
||||
"""Adds a "full" entity (User, Chat or Channel, not "Input*").
|
||||
|
||||
Not to be confused with UserFull, ChatFull, or ChannelFull,
|
||||
"full" means simply not "Input*".
|
||||
"""
|
||||
marked_id = utils.get_peer_id(
|
||||
utils.get_input_peer(entity), add_mark=True
|
||||
)
|
||||
try:
|
||||
old_entity = self._entities[marked_id]
|
||||
old_entity.__dict__.update(entity.__dict__) # Keep old references
|
||||
|
||||
# Update must delete old username
|
||||
username = getattr(old_entity, 'username', None)
|
||||
if username:
|
||||
del self._username_id[username.lower()]
|
||||
except KeyError:
|
||||
# Add new entity
|
||||
self._entities[marked_id] = entity
|
||||
|
||||
# Always update username if any
|
||||
username = getattr(entity, 'username', None)
|
||||
if username:
|
||||
self._username_id[username.lower()] = marked_id
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Accepts a digit only string as phone number,
|
||||
otherwise it's treated as an username.
|
||||
|
||||
If an integer is given, it's treated as the ID of the desired User.
|
||||
The ID given won't try to be guessed as the ID of a chat or channel,
|
||||
as there may be an user with that ID, and it would be unreliable.
|
||||
|
||||
If a Peer is given (PeerUser, PeerChat, PeerChannel),
|
||||
its specific entity is retrieved as User, Chat or Channel.
|
||||
Note that megagroups are channels with .megagroup = True.
|
||||
"""
|
||||
if isinstance(key, str):
|
||||
# TODO Parse phone properly, currently only usernames
|
||||
key = key.lstrip('@').lower()
|
||||
# TODO Use the client to return from username if not found
|
||||
return self._entities[self._username_id[key]]
|
||||
|
||||
if isinstance(key, int):
|
||||
return self._entities[key] # normal IDs are assumed users
|
||||
|
||||
if isinstance(key, TLObject) and type(key).SUBCLASS_OF_ID == 0x2d45687:
|
||||
return self._entities[utils.get_peer_id(key, add_mark=True)]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
target = self[key]
|
||||
del self._entities[key]
|
||||
if getattr(target, 'username'):
|
||||
del self._username_id[target.username]
|
||||
|
||||
# TODO Allow search by name by tokenizing the input and return a list
|
||||
|
||||
def get_input_list(self):
|
||||
return list(self._input_entities.items())
|
||||
|
||||
def clear(self, target=None):
|
||||
if target is None:
|
||||
self._entities.clear()
|
||||
else:
|
||||
del self[target]
|
|
@ -6,11 +6,8 @@ from base64 import b64encode, b64decode
|
|||
from os.path import isfile as file_exists
|
||||
from threading import Lock
|
||||
|
||||
from .. import helpers, utils
|
||||
from ..tl.types import (
|
||||
InputPeerUser, InputPeerChat, InputPeerChannel,
|
||||
PeerUser, PeerChat, PeerChannel
|
||||
)
|
||||
from .entity_database import EntityDatabase
|
||||
from .. import helpers
|
||||
|
||||
|
||||
class Session:
|
||||
|
@ -70,8 +67,7 @@ class Session:
|
|||
self.auth_key = None
|
||||
self.layer = 0
|
||||
self.salt = 0 # Unsigned long
|
||||
self._input_entities = {} # {marked_id: hash}
|
||||
self._entities_lock = Lock()
|
||||
self.entities = EntityDatabase() # Known and cached entities
|
||||
|
||||
def save(self):
|
||||
"""Saves the current session object as session_user_id.session"""
|
||||
|
@ -90,7 +86,7 @@ class Session:
|
|||
if self.auth_key else None
|
||||
}
|
||||
if self.save_entities:
|
||||
out_dict['entities'] = list(self._input_entities.items())
|
||||
out_dict['entities'] = self.entities.get_input_list()
|
||||
|
||||
json.dump(out_dict, file)
|
||||
|
||||
|
@ -139,8 +135,7 @@ class Session:
|
|||
key = b64decode(data['auth_key_data'])
|
||||
result.auth_key = AuthKey(data=key)
|
||||
|
||||
for e_mid, e_hash in data.get('entities', []):
|
||||
result._input_entities[e_mid] = e_hash
|
||||
result.entities = EntityDatabase(data.get('entities', []))
|
||||
|
||||
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
|
||||
pass
|
||||
|
@ -186,58 +181,5 @@ class Session:
|
|||
self.time_offset = correct - now
|
||||
|
||||
def process_entities(self, tlobject):
|
||||
"""Processes all the found entities on the given TLObject,
|
||||
unless .save_entities is False, and saves the session file.
|
||||
"""
|
||||
if not self.save_entities:
|
||||
return
|
||||
|
||||
# Save all input entities we know of
|
||||
entities = []
|
||||
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
|
||||
entities.extend(tlobject.chats)
|
||||
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
|
||||
entities.extend(tlobject.users)
|
||||
|
||||
if self.add_entities(entities):
|
||||
if self.entities.process(tlobject):
|
||||
self.save() # Save if any new entities got added
|
||||
|
||||
def add_entities(self, entities):
|
||||
"""Adds new input entities to the local database unconditionally.
|
||||
Unknown types will be ignored.
|
||||
"""
|
||||
if not entities:
|
||||
return False
|
||||
|
||||
new = {}
|
||||
for e in entities:
|
||||
try:
|
||||
p = utils.get_input_peer(e)
|
||||
new[utils.get_peer_id(p, add_mark=True)] = \
|
||||
getattr(p, 'access_hash', 0) # chats won't have hash
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
with self._entities_lock:
|
||||
before = len(self._input_entities)
|
||||
self._input_entities.update(new)
|
||||
return len(self._input_entities) != before
|
||||
|
||||
def get_input_entity(self, peer):
|
||||
"""Gets an input entity known its Peer or a marked ID,
|
||||
or raises KeyError if not found/invalid.
|
||||
"""
|
||||
if not isinstance(peer, int):
|
||||
peer = utils.get_peer_id(peer, add_mark=True)
|
||||
|
||||
entity_hash = self._input_entities[peer]
|
||||
entity_id, peer_class = utils.resolve_id(peer)
|
||||
|
||||
if peer_class == PeerUser:
|
||||
return InputPeerUser(entity_id, entity_hash)
|
||||
if peer_class == PeerChat:
|
||||
return InputPeerChat(entity_id)
|
||||
if peer_class == PeerChannel:
|
||||
return InputPeerChannel(entity_id, entity_hash)
|
||||
|
||||
raise KeyError()
|
||||
|
|
Loading…
Reference in New Issue
Block a user