Allow EntityDatabase to be accessed by phone

This commit is contained in:
Lonami Exo 2017-10-05 13:01:00 +02:00
parent a8edacd34a
commit 99cc0778bb
2 changed files with 43 additions and 41 deletions

View File

@ -1,7 +1,5 @@
import os import os
import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import lru_cache
from mimetypes import guess_type from mimetypes import guess_type
try: try:
@ -17,6 +15,7 @@ from .errors import (
) )
from .network import ConnectionMode from .network import ConnectionMode
from .tl import TLObject from .tl import TLObject
from .tl.entity_database import EntityDatabase
from .tl.functions.account import ( from .tl.functions.account import (
GetPasswordRequest GetPasswordRequest
) )
@ -127,7 +126,7 @@ class TelegramClient(TelegramBareClient):
def send_code_request(self, phone): def send_code_request(self, phone):
"""Sends a code request to the specified phone number""" """Sends a code request to the specified phone number"""
phone = self._parse_phone(phone) phone = EntityDatabase.parse_phone(phone) or self._phone
result = self(SendCodeRequest(phone, self.api_id, self.api_hash)) result = self(SendCodeRequest(phone, self.api_id, self.api_hash))
self._phone = phone self._phone = phone
self._phone_code_hash = result.phone_code_hash self._phone_code_hash = result.phone_code_hash
@ -160,7 +159,7 @@ class TelegramClient(TelegramBareClient):
if phone and not code: if phone and not code:
return self.send_code_request(phone) return self.send_code_request(phone)
elif code: elif code:
phone = self._parse_phone(phone) phone = EntityDatabase.parse_phone(phone) or self._phone
phone_code_hash = phone_code_hash or self._phone_code_hash phone_code_hash = phone_code_hash or self._phone_code_hash
if not phone: if not phone:
raise ValueError( raise ValueError(
@ -897,40 +896,20 @@ class TelegramClient(TelegramBareClient):
"""Gets an entity from the given string, which may be a phone or """Gets an entity from the given string, which may be a phone or
an username, and processes all the found entities on the session. an username, and processes all the found entities on the session.
""" """
stripped_phone = self._parse_phone(string, ignore_saved=True) phone = EntityDatabase.parse_phone(string)
if stripped_phone.isdigit(): if phone:
contacts = self(GetContactsRequest(0)) entity = phone
self.session.process_entities(contacts) self.session.process_entities(self(GetContactsRequest(0)))
try:
return next(
u for u in contacts.users
if u.phone and u.phone.endswith(stripped_phone)
)
except StopIteration:
raise ValueError(
'Could not find user with phone {}, '
'add them to your contacts first'.format(string)
)
else: else:
entity = string.strip('@').lower() entity = string.strip('@').lower()
self.session.process_entities(self(ResolveUsernameRequest(entity))) self.session.process_entities(self(ResolveUsernameRequest(entity)))
try:
return self.session.entities[entity]
except KeyError:
raise ValueError(
'Could not find user with username {}'.format(entity)
)
def _parse_phone(self, phone, ignore_saved=False): try:
if isinstance(phone, int): return self.session.entities[entity]
phone = str(phone) except KeyError:
elif phone: raise ValueError(
phone = re.sub(r'[+()\s-]', '', phone) 'Could not find user with username {}'.format(entity)
)
if ignore_saved:
return phone
else:
return phone or self._phone
def get_input_entity(self, peer): def get_input_entity(self, peer):
"""Gets the input entity given its PeerUser, PeerChat, PeerChannel. """Gets the input entity given its PeerUser, PeerChat, PeerChannel.
@ -940,7 +919,7 @@ class TelegramClient(TelegramBareClient):
If this Peer hasn't been seen before by the library, all dialogs If this Peer hasn't been seen before by the library, all dialogs
will loaded, and their entities saved to the session file. will loaded, and their entities saved to the session file.
If even after If even after it's not found, a ValueError is raised.
""" """
try: try:
# First try to get the entity from cache, otherwise figure it out # First try to get the entity from cache, otherwise figure it out

View File

@ -1,5 +1,7 @@
from threading import Lock from threading import Lock
import re
from .. import utils from .. import utils
from ..tl import TLObject from ..tl import TLObject
from ..tl.types import User, Chat, Channel from ..tl.types import User, Chat, Channel
@ -19,6 +21,7 @@ class EntityDatabase:
# TODO Allow disabling some extra mappings # TODO Allow disabling some extra mappings
self._username_id = {} # username: marked_id self._username_id = {} # username: marked_id
self._phone_id = {} # phone: marked_id
def process(self, tlobject): def process(self, tlobject):
"""Processes all the found entities on the given TLObject, """Processes all the found entities on the given TLObject,
@ -87,19 +90,27 @@ class EntityDatabase:
old_entity = self._entities[marked_id] old_entity = self._entities[marked_id]
old_entity.__dict__.update(entity.__dict__) # Keep old references old_entity.__dict__.update(entity.__dict__) # Keep old references
# Update must delete old username # Update must delete old username and phone
username = getattr(old_entity, 'username', None) username = getattr(old_entity, 'username', None)
if username: if username:
del self._username_id[username.lower()] del self._username_id[username.lower()]
phone = getattr(old_entity, 'phone', None)
if phone:
del self._phone_id[phone]
except KeyError: except KeyError:
# Add new entity # Add new entity
self._entities[marked_id] = entity self._entities[marked_id] = entity
# Always update username if any # Always update username or phone if any
username = getattr(entity, 'username', None) username = getattr(entity, 'username', None)
if username: if username:
self._username_id[username.lower()] = marked_id self._username_id[username.lower()] = marked_id
phone = getattr(entity, 'phone', None)
if phone:
self._username_id[phone] = marked_id
def __getitem__(self, key): def __getitem__(self, key):
"""Accepts a digit only string as phone number, """Accepts a digit only string as phone number,
otherwise it's treated as an username. otherwise it's treated as an username.
@ -113,10 +124,12 @@ class EntityDatabase:
Note that megagroups are channels with .megagroup = True. Note that megagroups are channels with .megagroup = True.
""" """
if isinstance(key, str): if isinstance(key, str):
# TODO Parse phone properly, currently only usernames phone = EntityDatabase.parse_phone(key)
key = key.lstrip('@').lower() if phone:
# TODO Use the client to return from username if not found return self._phone_id[phone]
return self._entities[self._username_id[key]] else:
key = key.lstrip('@').lower()
return self._entities[self._username_id[key]]
if isinstance(key, int): if isinstance(key, int):
return self._entities[key] # normal IDs are assumed users return self._entities[key] # normal IDs are assumed users
@ -140,6 +153,16 @@ class EntityDatabase:
# TODO Allow search by name by tokenizing the input and return a list # TODO Allow search by name by tokenizing the input and return a list
@staticmethod
def parse_phone(phone):
"""Parses the given phone, or returns None if it's invalid"""
if isinstance(phone, int):
return str(phone)
else:
phone = re.sub(r'[+()\s-]', '', str(phone))
if phone.isdigit():
return phone
def get_input_entity(self, peer): def get_input_entity(self, peer):
try: try:
return self._input_entities[utils.get_peer_id(peer, add_mark=True)] return self._input_entities[utils.get_peer_id(peer, add_mark=True)]