Remove custom EntityDatabase and use sqlite3 instead

There are still a few things to change, like cleaning up the
code and actually caching the entities as a whole (currently,
although the username/phone/name can be used to fetch their
input version which is an improvement, their full version
needs to be re-fetched. Maybe it's a good thing though?)
This commit is contained in:
Lonami Exo 2017-12-27 00:50:09 +01:00
parent 0a4849b150
commit aef96f1b68
4 changed files with 181 additions and 303 deletions

View File

@ -19,7 +19,6 @@ from .errors import (
from .network import ConnectionMode from .network import ConnectionMode
from .tl import TLObject from .tl import TLObject
from .tl.custom import Draft, Dialog from .tl.custom import Draft, Dialog
from .tl.entity_database import EntityDatabase
from .tl.functions.account import ( from .tl.functions.account import (
GetPasswordRequest GetPasswordRequest
) )
@ -144,7 +143,7 @@ class TelegramClient(TelegramBareClient):
:return auth.SentCode: :return auth.SentCode:
Information about the result of the request. Information about the result of the request.
""" """
phone = EntityDatabase.parse_phone(phone) or self._phone phone = utils.parse_phone(phone) or self._phone
if not self._phone_code_hash: if not self._phone_code_hash:
result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
@ -188,7 +187,7 @@ class TelegramClient(TelegramBareClient):
if phone and not code: if phone and not code:
return self.send_code_request(phone) return self.send_code_request(phone)
elif code: elif code:
phone = EntityDatabase.parse_phone(phone) or self._phone phone = utils.parse_phone(phone) or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash phone_code_hash = phone_code_hash or self._phone_code_hash
if not phone: if not phone:
raise ValueError( raise ValueError(
@ -1009,12 +1008,8 @@ class TelegramClient(TelegramBareClient):
may be out of date. may be out of date.
:return: :return:
""" """
if not force_fetch: # TODO Actually cache {id: entities} again
# Try to use cache unless we want to force a fetch # >>> if not force_fetch: reuse cached
try:
return self.session.entities[entity]
except KeyError:
pass
if isinstance(entity, int) or ( if isinstance(entity, int) or (
isinstance(entity, TLObject) and isinstance(entity, TLObject) and
@ -1022,36 +1017,38 @@ class TelegramClient(TelegramBareClient):
type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)): type(entity).SUBCLASS_OF_ID in (0xc91c90b6, 0x2d45687)):
ie = self.get_input_entity(entity) ie = self.get_input_entity(entity)
if isinstance(ie, InputPeerUser): if isinstance(ie, InputPeerUser):
self(GetUsersRequest([ie])) return self(GetUsersRequest([ie]))[0]
elif isinstance(ie, InputPeerChat): elif isinstance(ie, InputPeerChat):
self(GetChatsRequest([ie.chat_id])) return self(GetChatsRequest([ie.chat_id])).chats[0]
elif isinstance(ie, InputPeerChannel): elif isinstance(ie, InputPeerChannel):
self(GetChannelsRequest([ie])) return self(GetChannelsRequest([ie])).chats[0]
try:
# session.process_entities has been called in the MtProtoSender
# with the result of these calls, so they should now be on the
# entities database.
return self.session.entities[ie]
except KeyError:
pass
if isinstance(entity, str): if isinstance(entity, str):
return self._get_entity_from_string(entity) # TODO This probably can be done better...
invite = self._load_entity_from_string(entity)
if invite:
return invite
return self.get_entity(self.session.get_input_entity(entity))
raise ValueError( raise ValueError(
'Cannot turn "{}" into any entity (user or chat)'.format(entity) 'Cannot turn "{}" into any entity (user or chat)'.format(entity)
) )
def _get_entity_from_string(self, string): def _load_entity_from_string(self, string):
"""Gets an entity from the given string, which may be a phone or
an username, and processes all the found entities on the session.
""" """
phone = EntityDatabase.parse_phone(string) Loads an entity from the given string, which may be a phone or
an username, and processes all the found entities on the session.
This method will effectively add the found users to the session
database, so it can be queried later.
May return a channel or chat if the string was an invite.
"""
phone = utils.parse_phone(string)
if phone: if phone:
entity = phone
self(GetContactsRequest(0)) self(GetContactsRequest(0))
else: else:
entity, is_join_chat = EntityDatabase.parse_username(string) entity, is_join_chat = utils.parse_username(string)
if is_join_chat: if is_join_chat:
invite = self(CheckChatInviteRequest(entity)) invite = self(CheckChatInviteRequest(entity))
if isinstance(invite, ChatInvite): if isinstance(invite, ChatInvite):
@ -1063,13 +1060,6 @@ class TelegramClient(TelegramBareClient):
return invite.chat return invite.chat
else: else:
self(ResolveUsernameRequest(entity)) self(ResolveUsernameRequest(entity))
# MtProtoSender will call .process_entities on the requests made
try:
return self.session.entities[entity]
except KeyError:
raise ValueError(
'Could not find user with username {}'.format(entity)
)
def get_input_entity(self, peer): def get_input_entity(self, peer):
""" """
@ -1092,12 +1082,15 @@ class TelegramClient(TelegramBareClient):
""" """
try: try:
# First try to get the entity from cache, otherwise figure it out # First try to get the entity from cache, otherwise figure it out
return self.session.entities.get_input_entity(peer) return self.session.get_input_entity(peer)
except KeyError: except KeyError:
pass pass
if isinstance(peer, str): if isinstance(peer, str):
return utils.get_input_peer(self._get_entity_from_string(peer)) invite = self._load_entity_from_string(peer)
if invite:
return utils.get_input_peer(invite)
return self.session.get_input_entity(peer)
is_peer = False is_peer = False
if isinstance(peer, int): if isinstance(peer, int):
@ -1130,7 +1123,7 @@ class TelegramClient(TelegramBareClient):
exclude_pinned=True exclude_pinned=True
)) ))
try: try:
return self.session.entities.get_input_entity(peer) return self.session.get_input_entity(peer)
except KeyError: except KeyError:
pass pass

View File

@ -1,252 +0,0 @@
import re
from threading import Lock
from ..tl import TLObject
from ..tl.types import (
User, Chat, Channel, PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel
)
from .. import utils # Keep this line the last to maybe fix #357
USERNAME_RE = re.compile(
r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?'
)
class EntityDatabase:
def __init__(self, input_list=None, enabled=True, enabled_full=True):
"""Creates a new entity database with an initial load of "Input"
entities, if any.
If 'enabled', input entities will be saved. The whole entity
will be saved if both 'enabled' and 'enabled_full' are True.
"""
self.enabled = enabled
self.enabled_full = enabled_full
self._lock = Lock()
self._entities = {} # marked_id: user|chat|channel
if input_list:
# TODO For compatibility reasons some sessions were saved with
# 'access_hash': null in the JSON session file. Drop these, as
# it means we don't have access to such InputPeers. Issue #354.
self._input_entities = {
k: v for k, v in input_list if v is not None
}
else:
self._input_entities = {} # marked_id: hash
# TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id
self._phone_id = {} # phone: marked_id
def process(self, tlobject):
"""Processes all the found entities on the given TLObject,
unless .enabled is False.
Returns True if new input entities were added.
"""
if not self.enabled:
return False
# Save all input entities we know of
if not isinstance(tlobject, TLObject) and hasattr(tlobject, '__iter__'):
# This may be a list of users already for instance
return self.expand(tlobject)
entities = []
if hasattr(tlobject, 'chats') and hasattr(tlobject.chats, '__iter__'):
entities.extend(tlobject.chats)
if hasattr(tlobject, 'users') and hasattr(tlobject.users, '__iter__'):
entities.extend(tlobject.users)
return self.expand(entities)
def expand(self, entities):
"""Adds new input entities to the local database unconditionally.
Unknown types will be ignored.
"""
if not entities or not self.enabled:
return False
new = [] # Array of entities (User, Chat, or Channel)
new_input = {} # Dictionary of {entity_marked_id: access_hash}
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p, add_mark=True)
has_hash = False
if isinstance(p, InputPeerChat):
# Chats don't have a hash
new_input[marked_id] = 0
has_hash = True
elif p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
new_input[marked_id] = p.access_hash
has_hash = True
if self.enabled_full and has_hash:
if isinstance(e, (User, Chat, Channel)):
new.append(e)
except ValueError:
pass
with self._lock:
before = len(self._input_entities)
self._input_entities.update(new_input)
for e in new:
self._add_full_entity(e)
return len(self._input_entities) != before
def _add_full_entity(self, entity):
"""Adds a "full" entity (User, Chat or Channel, not "Input*"),
despite the value of self.enabled and self.enabled_full.
Not to be confused with UserFull, ChatFull, or ChannelFull,
"full" means simply not "Input*".
"""
marked_id = utils.get_peer_id(
utils.get_input_peer(entity, allow_self=False), add_mark=True
)
try:
old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity.__dict__) # Keep old references
# Update must delete old username and phone
username = getattr(old_entity, 'username', None)
if username:
del self._username_id[username.lower()]
phone = getattr(old_entity, 'phone', None)
if phone:
del self._phone_id[phone]
except KeyError:
# Add new entity
self._entities[marked_id] = entity
# Always update username or phone if any
username = getattr(entity, 'username', None)
if username:
self._username_id[username.lower()] = marked_id
phone = getattr(entity, 'phone', None)
if phone:
self._phone_id[phone] = marked_id
def _parse_key(self, key):
"""Parses the given string, integer or TLObject key into a
marked user ID ready for use on self._entities.
If a callable key is given, the entity will be passed to the
function, and if it returns a true-like value, the marked ID
for such entity will be returned.
Raises ValueError if it cannot be parsed.
"""
if isinstance(key, str):
phone = EntityDatabase.parse_phone(key)
try:
if phone:
return self._phone_id[phone]
else:
username, _ = EntityDatabase.parse_username(key)
return self._username_id[username.lower()]
except KeyError as e:
raise ValueError() from e
if isinstance(key, int):
return key # normal IDs are assumed users
if isinstance(key, TLObject):
return utils.get_peer_id(key, add_mark=True)
if callable(key):
for k, v in self._entities.items():
if key(v):
return k
raise ValueError()
def __getitem__(self, key):
"""See the ._parse_key() docstring for possible values of the key"""
try:
return self._entities[self._parse_key(key)]
except (ValueError, KeyError) as e:
raise KeyError(key) from e
def __delitem__(self, key):
try:
old = self._entities.pop(self._parse_key(key))
# Try removing the username and phone (if pop didn't fail),
# since the entity may have no username or phone, just ignore
# errors. It should be there if we popped the entity correctly.
try:
del self._username_id[getattr(old, 'username', None)]
except KeyError:
pass
try:
del self._phone_id[getattr(old, 'phone', None)]
except KeyError:
pass
except (ValueError, KeyError) as e:
raise KeyError(key) from e
@staticmethod
def parse_phone(phone):
"""Parses the given phone, or returns None if it's invalid"""
if isinstance(phone, int):
return str(phone)
else:
phone = re.sub(r'[+()\s-]', '', str(phone))
if phone.isdigit():
return phone
@staticmethod
def parse_username(username):
"""Parses the given username or channel access hash, given
a string, username or URL. Returns a tuple consisting of
both the stripped username and whether it is a joinchat/ hash.
"""
username = username.strip()
m = USERNAME_RE.match(username)
if m:
return username[m.end():], bool(m.group(1))
else:
return username, False
def get_input_entity(self, peer):
try:
i = utils.get_peer_id(peer, add_mark=True)
h = self._input_entities[i] # we store the IDs marked
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
except ValueError as e:
raise KeyError(peer) from e
raise KeyError(peer)
def get_input_list(self):
return list(self._input_entities.items())
def clear(self, target=None):
if target is None:
self._entities.clear()
else:
del self[target]

View File

@ -8,8 +8,12 @@ from base64 import b64decode
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock from threading import Lock
from .entity_database import EntityDatabase from .. import utils, helpers
from .. import helpers from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel
)
EXTENSION = '.session' EXTENSION = '.session'
CURRENT_VERSION = 1 # database version CURRENT_VERSION = 1 # database version
@ -75,10 +79,9 @@ class Session:
self._auth_key = None self._auth_key = None
self._layer = 0 self._layer = 0
self._salt = 0 # Signed long self._salt = 0 # Signed long
self.entities = EntityDatabase() # Known and cached entities
# Migrating from .json -> SQL # Migrating from .json -> SQL
self._check_migrate_json() entities = self._check_migrate_json()
self._conn = sqlite3.connect(self.filename, check_same_thread=False) self._conn = sqlite3.connect(self.filename, check_same_thread=False)
c = self._conn.cursor() c = self._conn.cursor()
@ -114,14 +117,20 @@ class Session:
) )
c.execute( c.execute(
"""create table entities ( """create table entities (
id integer, id integer primary key,
hash integer, hash integer not null,
username text, username text,
phone integer, phone integer,
name text name text
)""" )"""
) )
c.execute("insert into version values (1)") c.execute("insert into version values (1)")
# Migrating from JSON -> new table and may have entities
if entities:
c.executemany(
'insert or replace into entities values (?,?,?,?,?)',
entities
)
c.close() c.close()
self.save() self.save()
@ -130,6 +139,8 @@ class Session:
try: try:
with open(self.filename, encoding='utf-8') as f: with open(self.filename, encoding='utf-8') as f:
data = json.load(f) data = json.load(f)
self.delete() # Delete JSON file to create database
self._port = data.get('port', self._port) self._port = data.get('port', self._port)
self._salt = data.get('salt', self._salt) self._salt = data.get('salt', self._salt)
# Keep while migrating from unsigned to signed salt # Keep while migrating from unsigned to signed salt
@ -146,10 +157,12 @@ class Session:
key = b64decode(data['auth_key_data']) key = b64decode(data['auth_key_data'])
self._auth_key = AuthKey(data=key) self._auth_key = AuthKey(data=key)
self.entities = EntityDatabase(data.get('entities', [])) rows = []
self.delete() # Delete JSON file to create database for p_id, p_hash in data.get('entities', []):
rows.append((p_id, p_hash, None, None, None))
return rows
except (UnicodeDecodeError, json.decoder.JSONDecodeError): except (UnicodeDecodeError, json.decoder.JSONDecodeError):
pass return [] # No entities
def _upgrade_database(self, old): def _upgrade_database(self, old):
pass pass
@ -275,9 +288,103 @@ class Session:
correct = correct_msg_id >> 32 correct = correct_msg_id >> 32
self.time_offset = correct - now self.time_offset = correct - now
def process_entities(self, tlobject): # Entity processing
try:
if self.entities.process(tlobject): def process_entities(self, tlo):
self.save() # Save if any new entities got added """Processes all the found entities on the given TLObject,
except: unless .enabled is False.
pass
Returns True if new input entities were added.
"""
if not self.save_entities:
return
if not isinstance(tlo, TLObject) and hasattr(tlo, '__iter__'):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and hasattr(tlo.chats, '__iter__'):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and hasattr(tlo.users, '__iter__'):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p, add_mark=True)
p_hash = None
if isinstance(p, InputPeerChat):
p_hash = 0
elif p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
p_hash = p.access_hash
if p_hash is not None:
username = getattr(e, 'username', None)
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
except ValueError:
pass
if not rows:
return
with self._db_lock:
self._conn.executemany(
'insert or replace into entities values (?,?,?,?,?)', rows
)
self.save()
def get_input_entity(self, key):
"""Parses the given string, integer or TLObject key into a
marked entity ID, which is then used to fetch the hash
from the database.
If a callable key is given, every row will be fetched,
and passed as a tuple to a function, that should return
a true-like value when the desired row is found.
Raises ValueError if it cannot be found.
"""
c = self._conn.cursor()
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
c.execute('select id, hash from entities where phone=?',
(phone,))
else:
username, _ = utils.parse_username(key)
c.execute('select id, hash from entities where username=?',
(username,))
if isinstance(key, TLObject):
# crc32(b'InputPeer') and crc32(b'Peer')
if type(key).SUBCLASS_OF_ID == 0xc91c90b6:
return key
key = utils.get_peer_id(key, add_mark=True)
if isinstance(key, int):
c.execute('select id, hash from entities where id=?', (key,))
result = c.fetchone()
if result:
i, h = result # unpack resulting tuple
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)

View File

@ -5,6 +5,8 @@ to convert between an entity like an User, Chat, etc. into its Input version)
import math import math
from mimetypes import add_type, guess_extension from mimetypes import add_type, guess_extension
import re
from .tl import TLObject from .tl import TLObject
from .tl.types import ( from .tl.types import (
Channel, ChannelForbidden, Chat, ChatEmpty, ChatForbidden, ChatFull, Channel, ChannelForbidden, Chat, ChatEmpty, ChatForbidden, ChatFull,
@ -24,6 +26,11 @@ from .tl.types import (
) )
USERNAME_RE = re.compile(
r'@|(?:https?://)?(?:telegram\.(?:me|dog)|t\.me)/(joinchat/)?'
)
def get_display_name(entity): def get_display_name(entity):
"""Gets the input peer for the given "entity" (user, chat or channel) """Gets the input peer for the given "entity" (user, chat or channel)
Returns None if it was not found""" Returns None if it was not found"""
@ -305,6 +312,29 @@ def get_input_media(media, user_caption=None, is_photo=False):
_raise_cast_fail(media, 'InputMedia') _raise_cast_fail(media, 'InputMedia')
def parse_phone(phone):
"""Parses the given phone, or returns None if it's invalid"""
if isinstance(phone, int):
return str(phone)
else:
phone = re.sub(r'[+()\s-]', '', str(phone))
if phone.isdigit():
return phone
def parse_username(username):
"""Parses the given username or channel access hash, given
a string, username or URL. Returns a tuple consisting of
both the stripped username and whether it is a joinchat/ hash.
"""
username = username.strip()
m = USERNAME_RE.match(username)
if m:
return username[m.end():], bool(m.group(1))
else:
return username, False
def get_peer_id(peer, add_mark=False): def get_peer_id(peer, add_mark=False):
"""Finds the ID of the given peer, and optionally converts it to """Finds the ID of the given peer, and optionally converts it to
the "bot api" format if 'add_mark' is set to True. the "bot api" format if 'add_mark' is set to True.