Save all found entities to the session file

This commit is contained in:
Lonami Exo 2017-10-01 13:24:04 +02:00
parent 0a4a898c49
commit a737f33324
3 changed files with 106 additions and 8 deletions

View File

@ -493,10 +493,22 @@ class TelegramBareClient:
# "A container may only be accepted or
# rejected by the other party as a whole."
return None
elif len(requests) == 1:
return requests[0].result
else:
return [x.result for x in requests]
# Save all input entities we know of
entities = []
results = []
for x in requests:
y = x.result
results.append(y)
if hasattr(y, 'chats') and hasattr(y.chats, '__iter__'):
entities.extend(y.chats)
if hasattr(y, 'users') and hasattr(y.users, '__iter__'):
entities.extend(y.users)
if self.session.add_entities(entities):
self.session.save() # Save if any new entities got added
return results[0] if len(results) == 1 else results
except (PhoneMigrateError, NetworkMigrateError,
UserMigrateError) as e:

View File

@ -2,11 +2,15 @@ import json
import os
import platform
import time
from threading import Lock
from base64 import b64encode, b64decode
from os.path import isfile as file_exists
from threading import Lock
from .. import helpers as utils
from .. import helpers, utils
from ..tl.types import (
InputPeerUser, InputPeerChat, InputPeerChannel,
PeerUser, PeerChat, PeerChannel
)
class Session:
@ -51,7 +55,7 @@ class Session:
# Cross-thread safety
self._lock = Lock()
self.id = utils.generate_random_long(signed=False)
self.id = helpers.generate_random_long(signed=False)
self._sequence = 0
self.time_offset = 0
self._last_msg_id = 0 # Long
@ -62,6 +66,8 @@ class Session:
self.auth_key = None
self.layer = 0
self.salt = 0 # Unsigned long
self._input_entities = {} # {marked_id: hash}
self._entities_lock = Lock()
def save(self):
"""Saves the current session object as session_user_id.session"""
@ -74,7 +80,8 @@ class Session:
'server_address': self.server_address,
'auth_key_data':
b64encode(self.auth_key.key).decode('ascii')
if self.auth_key else None
if self.auth_key else None,
'entities': list(self._input_entities.items())
}, file)
def delete(self):
@ -122,6 +129,9 @@ class Session:
key = b64decode(data['auth_key_data'])
result.auth_key = AuthKey(data=key)
for e_mid, e_hash in data.get('entities', []):
result._input_entities[e_mid] = e_hash
except (json.decoder.JSONDecodeError, UnicodeDecodeError):
pass
@ -164,3 +174,43 @@ class Session:
now = int(time.time())
correct = correct_msg_id >> 32
self.time_offset = correct - now
def add_entities(self, entities):
"""Adds new input entities to the local database of them.
Unknown types will be ignored.
"""
if not entities:
return False
new = {}
for e in entities:
try:
p = utils.get_input_peer(e)
new[utils.get_peer_id(p, add_mark=True)] = \
getattr(p, 'access_hash', 0) # chats won't have hash
except ValueError:
pass
with self._entities_lock:
before = len(self._input_entities)
self._input_entities.update(new)
return len(self._input_entities) != before
def get_input_entity(self, peer):
"""Gets an input entity known its Peer or a marked ID,
or raises KeyError if not found/invalid.
"""
if not isinstance(peer, int):
peer = utils.get_peer_id(peer, add_mark=True)
entity_hash = self._input_entities[peer]
entity_id, peer_class = utils.resolve_id(peer)
if peer_class == PeerUser:
return InputPeerUser(entity_id, entity_hash)
if peer_class == PeerChat:
return InputPeerChat(entity_id)
if peer_class == PeerChannel:
return InputPeerChannel(entity_id, entity_hash)
raise KeyError()

View File

@ -2,6 +2,7 @@
Utilities for working with the Telegram API itself (such as handy methods
to convert between an entity like an User, Chat, etc. into its Input version)
"""
import math
from mimetypes import add_type, guess_extension
from .tl import TLObject
@ -299,6 +300,41 @@ def get_input_media(media, user_caption=None, is_photo=False):
_raise_cast_fail(media, 'InputMedia')
def get_peer_id(peer, add_mark=False):
"""Finds the ID of the given peer, and optionally converts it to
the "bot api" format if 'add_mark' is set to True.
"""
if isinstance(peer, PeerUser) or isinstance(peer, InputPeerUser):
return peer.user_id
else:
if isinstance(peer, PeerChat) or isinstance(peer, InputPeerChat):
if not add_mark:
return peer.chat_id
# Chats are marked by turning them into negative numbers
return -peer.chat_id
elif isinstance(peer, PeerChannel) or isinstance(peer, InputPeerChannel):
if not add_mark:
return peer.channel_id
# Prepend -100 through math tricks (.to_supergroup() on Madeline)
i = peer.channel_id # IDs will be strictly positive -> log works
return -(i + pow(10, math.floor(math.log10(i) + 3)))
_raise_cast_fail(peer, 'int')
def resolve_id(marked_id):
"""Given a marked ID, returns the original ID and its Peer type"""
if marked_id >= 0:
return marked_id, PeerUser
if str(marked_id).startswith('-100'):
return int(str(marked_id)[4:]), PeerChannel
return -marked_id, PeerChat
def find_user_or_chat(peer, users, chats):
"""Finds the corresponding user or chat given a peer.
Returns None if it was not found"""