Split Session into three parts and make a module for sessions

This commit is contained in:
Tulir Asokan 2018-03-01 23:34:32 +02:00
parent f09ab6c6b6
commit c5e6f7e265
6 changed files with 491 additions and 193 deletions

View File

@ -402,13 +402,13 @@ class MtProtoSender:
elif bad_msg.error_code == 32:
# msg_seqno too low, so just pump it up by some "large" amount
# TODO A better fix would be to start with a new fresh session ID
self.session._sequence += 64
self.session.sequence += 64
__log__.info('Attempting to set the right higher sequence')
self._resend_request(bad_msg.bad_msg_id)
return True
elif bad_msg.error_code == 33:
# msg_seqno too high never seems to happen but just in case
self.session._sequence -= 16
self.session.sequence -= 16
__log__.info('Attempting to set the right lower sequence')
self._resend_request(bad_msg.bad_msg_id)
return True

View File

@ -0,0 +1,3 @@
from .abstract import Session
from .memory import MemorySession
from .sqlite import SQLiteSession

View File

@ -0,0 +1,136 @@
from abc import ABC, abstractmethod
class Session(ABC):
@abstractmethod
def clone(self):
raise NotImplementedError
@abstractmethod
def set_dc(self, dc_id, server_address, port):
raise NotImplementedError
@property
@abstractmethod
def server_address(self):
raise NotImplementedError
@property
@abstractmethod
def port(self):
raise NotImplementedError
@property
@abstractmethod
def auth_key(self):
raise NotImplementedError
@auth_key.setter
@abstractmethod
def auth_key(self, value):
raise NotImplementedError
@property
@abstractmethod
def time_offset(self):
raise NotImplementedError
@time_offset.setter
@abstractmethod
def time_offset(self, value):
raise NotImplementedError
@property
@abstractmethod
def salt(self):
raise NotImplementedError
@salt.setter
@abstractmethod
def salt(self, value):
raise NotImplementedError
@property
@abstractmethod
def device_model(self):
raise NotImplementedError
@property
@abstractmethod
def system_version(self):
raise NotImplementedError
@property
@abstractmethod
def app_version(self):
raise NotImplementedError
@property
@abstractmethod
def lang_code(self):
raise NotImplementedError
@property
@abstractmethod
def system_lang_code(self):
raise NotImplementedError
@property
@abstractmethod
def report_errors(self):
raise NotImplementedError
@property
@abstractmethod
def sequence(self):
raise NotImplementedError
@property
@abstractmethod
def flood_sleep_threshold(self):
raise NotImplementedError
@abstractmethod
def close(self):
raise NotImplementedError
@abstractmethod
def save(self):
raise NotImplementedError
@abstractmethod
def delete(self):
raise NotImplementedError
@classmethod
@abstractmethod
def list_sessions(cls):
raise NotImplementedError
@abstractmethod
def get_new_msg_id(self):
raise NotImplementedError
@abstractmethod
def update_time_offset(self, correct_msg_id):
raise NotImplementedError
@abstractmethod
def generate_sequence(self, content_related):
raise NotImplementedError
@abstractmethod
def process_entities(self, tlo):
raise NotImplementedError
@abstractmethod
def get_input_entity(self, key):
raise NotImplementedError
@abstractmethod
def cache_file(self, md5_digest, file_size, instance):
raise NotImplementedError
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
raise NotImplementedError

297
telethon/sessions/memory.py Normal file
View File

@ -0,0 +1,297 @@
from enum import Enum
import time
import platform
from .. import utils
from .abstract import Session
from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument
)
class _SentFileType(Enum):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == InputDocument:
return _SentFileType.DOCUMENT
elif cls == InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class MemorySession(Session):
def __init__(self):
self._dc_id = None
self._server_address = None
self._port = None
self._salt = None
self._auth_key = None
self._sequence = 0
self._last_msg_id = 0
self._time_offset = 0
self._flood_sleep_threshold = 60
system = platform.uname()
self._device_model = system.system or 'Unknown'
self._system_version = system.release or '1.0'
self._app_version = '1.0'
self._lang_code = 'en'
self._system_lang_code = self.lang_code
self._report_errors = True
self._flood_sleep_threshold = 60
self._files = {}
self._entities = set()
def clone(self):
cloned = MemorySession()
cloned._device_model = self.device_model
cloned._system_version = self.system_version
cloned._app_version = self.app_version
cloned._lang_code = self.lang_code
cloned._system_lang_code = self.system_lang_code
cloned._report_errors = self.report_errors
cloned._flood_sleep_threshold = self.flood_sleep_threshold
def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id
self._server_address = server_address
self._port = port
@property
def server_address(self):
return self._server_address
@property
def port(self):
return self._port
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, value):
self._auth_key = value
@property
def time_offset(self):
return self._time_offset
@time_offset.setter
def time_offset(self, value):
self._time_offset = value
@property
def salt(self):
return self._salt
@salt.setter
def salt(self, value):
self._salt = value
@property
def device_model(self):
return self._device_model
@property
def system_version(self):
return self._system_version
@property
def app_version(self):
return self._app_version
@property
def lang_code(self):
return self._lang_code
@property
def system_lang_code(self):
return self._system_lang_code
@property
def report_errors(self):
return self._report_errors
@property
def sequence(self):
return self._sequence
@property
def flood_sleep_threshold(self):
return self._flood_sleep_threshold
def close(self):
pass
def save(self):
pass
def delete(self):
pass
@classmethod
def list_sessions(cls):
raise NotImplementedError
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
now = time.time() + self._time_offset
nanoseconds = int((now - int(now)) * 1e+9)
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
now = int(time.time())
correct = correct_msg_id >> 32
self._time_offset = correct - now
self._last_msg_id = 0
def generate_sequence(self, content_related):
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
@staticmethod
def _entities_to_rows(tlo):
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
continue
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
continue
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
continue
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
return rows
def process_entities(self, tlo):
self._entities += set(self._entities_to_rows(tlo))
def get_entity_rows_by_phone(self, phone):
rows = [(id, hash) for id, hash, _, found_phone, _
in self._entities if found_phone == phone]
return rows[0] if rows else None
def get_entity_rows_by_username(self, username):
rows = [(id, hash) for id, hash, found_username, _, _
in self._entities if found_username == username]
return rows[0] if rows else None
def get_entity_rows_by_name(self, name):
rows = [(id, hash) for id, hash, _, _, found_name
in self._entities if found_name == name]
return rows[0] if rows else None
def get_entity_rows_by_id(self, id):
rows = [(id, hash) for found_id, hash, _, _, _
in self._entities if found_id == id]
return rows[0] if rows else None
def get_input_entity(self, key):
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
# We already have an Input version, so nothing else required
return key
# Try to early return if this key can be casted as input peer
return utils.get_input_peer(key)
except (AttributeError, TypeError):
# Not a TLObject or can't be cast into InputPeer
if isinstance(key, TLObject):
key = utils.get_peer_id(key)
result = None
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
result = self.get_entity_rows_by_phone(phone)
else:
username, _ = utils.parse_username(key)
if username:
result = self.get_entity_rows_by_username(username)
if isinstance(key, int):
result = self.get_entity_rows_by_id(key)
if not result and isinstance(key, str):
result = self.get_entity_rows_by_name(key)
if result:
i, h = result # unpack resulting tuple
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
key = (md5_digest, file_size, _SentFileType.from_type(instance))
value = (instance.id, instance.access_hash)
self._files[key] = value
def get_file(self, md5_digest, file_size, cls):
key = (md5_digest, file_size, _SentFileType.from_type(cls))
try:
return self._files[key]
except KeyError:
return None

View File

@ -5,14 +5,15 @@ import sqlite3
import struct
import time
from base64 import b64decode
from enum import Enum
from os.path import isfile as file_exists
from threading import Lock, RLock
from . import utils
from .crypto import AuthKey
from .tl import TLObject
from .tl.types import (
from .. import utils
from .abstract import Session
from .memory import MemorySession, _SentFileType
from ..crypto import AuthKey
from ..tl import TLObject
from ..tl.types import (
PeerUser, PeerChat, PeerChannel,
InputPeerUser, InputPeerChat, InputPeerChannel,
InputPhoto, InputDocument
@ -22,21 +23,7 @@ EXTENSION = '.session'
CURRENT_VERSION = 3 # database version
class _SentFileType(Enum):
DOCUMENT = 0
PHOTO = 1
@staticmethod
def from_type(cls):
if cls == InputDocument:
return _SentFileType.DOCUMENT
elif cls == InputPhoto:
return _SentFileType.PHOTO
else:
raise ValueError('The cls must be either InputDocument/InputPhoto')
class Session:
class SQLiteSession(MemorySession):
"""This session contains the required information to login into your
Telegram account. NEVER give the saved JSON file to anyone, since
they would gain instant access to all your messages and contacts.
@ -44,7 +31,9 @@ class Session:
If you think the session has been compromised, close all the sessions
through an official Telegram client to revoke the authorization.
"""
def __init__(self, session_id):
super().__init__()
"""session_user_id should either be a string or another Session.
Note that if another session is given, only parameters like
those required to init a connection will be copied.
@ -54,15 +43,15 @@ class Session:
# For connection purposes
if isinstance(session_id, Session):
self.device_model = session_id.device_model
self.system_version = session_id.system_version
self.app_version = session_id.app_version
self.lang_code = session_id.lang_code
self.system_lang_code = session_id.system_lang_code
self.lang_pack = session_id.lang_pack
self.report_errors = session_id.report_errors
self.save_entities = session_id.save_entities
self.flood_sleep_threshold = session_id.flood_sleep_threshold
self._device_model = session_id.device_model
self._system_version = session_id.system_version
self._app_version = session_id.app_version
self._lang_code = session_id.lang_code
self._system_lang_code = session_id.system_lang_code
self._report_errors = session_id.report_errors
self._flood_sleep_threshold = session_id.flood_sleep_threshold
if isinstance(session_id, SQLiteSession):
self.save_entities = session_id.save_entities
else: # str / None
if session_id:
self.filename = session_id
@ -70,15 +59,14 @@ class Session:
self.filename += EXTENSION
system = platform.uname()
self.device_model = system.system or 'Unknown'
self.system_version = system.release or '1.0'
self.app_version = '1.0' # '0' will provoke error
self.lang_code = 'en'
self.system_lang_code = self.lang_code
self.lang_pack = ''
self.report_errors = True
self._device_model = system.system or 'Unknown'
self._system_version = system.release or '1.0'
self._app_version = '1.0' # '0' will provoke error
self._lang_code = 'en'
self._system_lang_code = self.lang_code
self._report_errors = True
self.save_entities = True
self.flood_sleep_threshold = 60
self._flood_sleep_threshold = 60
self.id = struct.unpack('q', os.urandom(8))[0]
self._sequence = 0
@ -163,6 +151,9 @@ class Session:
c.close()
self.save()
def clone(self):
return SQLiteSession(self)
def _check_migrate_json(self):
if file_exists(self.filename):
try:
@ -233,19 +224,7 @@ class Session:
self._auth_key = None
c.close()
@property
def server_address(self):
return self._server_address
@property
def port(self):
return self._port
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
@Session.auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
@ -298,53 +277,14 @@ class Session:
except OSError:
return False
@staticmethod
def list_sessions():
@classmethod
def list_sessions(cls):
"""Lists all the sessions of the users who have ever connected
using this client and never logged out
"""
return [os.path.splitext(os.path.basename(f))[0]
for f in os.listdir('.') if f.endswith(EXTENSION)]
def generate_sequence(self, content_related):
"""Thread safe method to generates the next sequence number,
based on whether it was confirmed yet or not.
Note that if confirmed=True, the sequence number
will be increased by one too
"""
with self._seq_no_lock:
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
# Refer to mtproto_plain_sender.py for the original method
now = time.time() + self.time_offset
nanoseconds = int((now - int(now)) * 1e+9)
# "message identifiers are divisible by 4"
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
with self._msg_id_lock:
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
"""Updates the time offset based on a known correct message ID"""
now = int(time.time())
correct = correct_msg_id >> 32
self.time_offset = correct - now
self._last_msg_id = 0
# Entity processing
def process_entities(self, tlo):
@ -356,49 +296,7 @@ class Session:
if not self.save_entities:
return
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):
# This may be a list of users already for instance
entities = tlo
else:
entities = []
if hasattr(tlo, 'chats') and utils.is_list_like(tlo.chats):
entities.extend(tlo.chats)
if hasattr(tlo, 'users') and utils.is_list_like(tlo.users):
entities.extend(tlo.users)
if not entities:
return
rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities:
if not isinstance(e, TLObject):
continue
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
continue
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
continue
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
continue
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
rows = self._entities_to_rows(tlo)
if not rows:
return
@ -408,62 +306,26 @@ class Session:
)
self.save()
def get_input_entity(self, key):
"""Parses the given string, integer or TLObject key into a
marked entity ID, which is then used to fetch the hash
from the database.
If a callable key is given, every row will be fetched,
and passed as a tuple to a function, that should return
a true-like value when the desired row is found.
Raises ValueError if it cannot be found.
"""
try:
if key.SUBCLASS_OF_ID in (0xc91c90b6, 0xe669bf46, 0x40f202fd):
# hex(crc32(b'InputPeer', b'InputUser' and b'InputChannel'))
# We already have an Input version, so nothing else required
return key
# Try to early return if this key can be casted as input peer
return utils.get_input_peer(key)
except (AttributeError, TypeError):
# Not a TLObject or can't be cast into InputPeer
if isinstance(key, TLObject):
key = utils.get_peer_id(key)
def _fetchone_entity(self, query, args):
c = self._cursor()
if isinstance(key, str):
phone = utils.parse_phone(key)
if phone:
c.execute('select id, hash from entities where phone=?',
(phone,))
else:
username, _ = utils.parse_username(key)
if username:
c.execute('select id, hash from entities where username=?',
c.execute(query, args)
return c.fetchone()
def get_entity_rows_by_phone(self, phone):
return self._fetchone_entity(
'select id, hash from entities where phone=?', (phone,))
def get_entity_rows_by_username(self, username):
self._fetchone_entity('select id, hash from entities where username=?',
(username,))
if isinstance(key, int):
c.execute('select id, hash from entities where id=?', (key,))
def get_entity_rows_by_name(self, name):
self._fetchone_entity('select id, hash from entities where name=?',
(name,))
result = c.fetchone()
if not result and isinstance(key, str):
# Try exact match by name if phone/username failed
c.execute('select id, hash from entities where name=?', (key,))
result = c.fetchone()
c.close()
if result:
i, h = result # unpack resulting tuple
i, k = utils.resolve_id(i) # removes the mark and returns kind
if k == PeerUser:
return InputPeerUser(i, h)
elif k == PeerChat:
return InputPeerChat(i)
elif k == PeerChannel:
return InputPeerChannel(i, h)
else:
raise ValueError('Could not find input entity with key ', key)
def get_entity_rows_by_id(self, id):
self._fetchone_entity('select id, hash from entities where id=?',
(id,))
# File processing

View File

@ -14,7 +14,7 @@ from .errors import (
PhoneMigrateError, NetworkMigrateError, UserMigrateError
)
from .network import authenticator, MtProtoSender, Connection, ConnectionMode
from .session import Session
from .sessions import Session, SQLiteSession
from .tl import TLObject
from .tl.all_tlobjects import LAYER
from .tl.functions import (
@ -81,10 +81,10 @@ class TelegramBareClient:
"Refer to telethon.rtfd.io for more information.")
self._use_ipv6 = use_ipv6
# Determine what session object we have
if isinstance(session, str) or session is None:
session = Session(session)
session = SQLiteSession(session)
elif not isinstance(session, Session):
raise TypeError(
'The given session must be a str or a Session instance.'
@ -361,7 +361,7 @@ class TelegramBareClient:
#
# Construct this session with the connection parameters
# (system version, device model...) from the current one.
session = Session(self.session)
session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[dc_id] = session
@ -387,7 +387,7 @@ class TelegramBareClient:
session = self._exported_sessions.get(cdn_redirect.dc_id)
if not session:
dc = self._get_dc(cdn_redirect.dc_id, cdn=True)
session = Session(self.session)
session = self.session.clone()
session.set_dc(dc.id, dc.ip_address, dc.port)
self._exported_sessions[cdn_redirect.dc_id] = session