Add more abstraction

This commit is contained in:
Tulir Asokan 2018-03-02 13:20:11 +02:00
parent d9a73744a4
commit 118d9b10e8
3 changed files with 53 additions and 44 deletions

View File

@ -18,8 +18,8 @@ class Session(ABC):
self._report_errors = True self._report_errors = True
self._flood_sleep_threshold = 60 self._flood_sleep_threshold = 60
def clone(self): def clone(self, to_instance=None):
cloned = self.__class__() cloned = to_instance or self.__class__()
cloned._device_model = self.device_model cloned._device_model = self.device_model
cloned._system_version = self.system_version cloned._system_version = self.system_version
cloned._app_version = self.app_version cloned._app_version = self.app_version
@ -27,6 +27,7 @@ class Session(ABC):
cloned._system_lang_code = self.system_lang_code cloned._system_lang_code = self.system_lang_code
cloned._report_errors = self.report_errors cloned._report_errors = self.report_errors
cloned._flood_sleep_threshold = self.flood_sleep_threshold cloned._flood_sleep_threshold = self.flood_sleep_threshold
return cloned
@abstractmethod @abstractmethod
def set_dc(self, dc_id, server_address, port): def set_dc(self, dc_id, server_address, port):

View File

@ -111,8 +111,41 @@ class MemorySession(Session):
def list_sessions(cls): def list_sessions(cls):
raise NotImplementedError raise NotImplementedError
@staticmethod def _entity_values_to_row(self, id, hash, username, phone, name):
def _entities_to_rows(tlo): return id, hash, username, phone, name
def _entity_to_row(self, e):
if not isinstance(e, TLObject):
return
try:
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
return
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
return
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
return
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
return self._entity_values_to_row(marked_id, p_hash, username, phone, name)
def _entities_to_rows(self, tlo):
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo): if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):
# This may be a list of users already for instance # This may be a list of users already for instance
entities = tlo entities = tlo
@ -127,35 +160,9 @@ class MemorySession(Session):
rows = [] # Rows to add (id, hash, username, phone, name) rows = [] # Rows to add (id, hash, username, phone, name)
for e in entities: for e in entities:
if not isinstance(e, TLObject): row = self._entity_to_row(e)
continue if row:
try: rows.append(row)
p = utils.get_input_peer(e, allow_self=False)
marked_id = utils.get_peer_id(p)
except ValueError:
continue
if isinstance(p, (InputPeerUser, InputPeerChannel)):
if not p.access_hash:
# Some users and channels seem to be returned without
# an 'access_hash', meaning Telegram doesn't want you
# to access them. This is the reason behind ensuring
# that the 'access_hash' is non-zero. See issue #354.
# Note that this checks for zero or None, see #392.
continue
else:
p_hash = p.access_hash
elif isinstance(p, InputPeerChat):
p_hash = 0
else:
continue
username = getattr(e, 'username', None) or None
if username is not None:
username = username.lower()
phone = getattr(e, 'phone', None)
name = utils.get_display_name(e) or None
rows.append((marked_id, p_hash, username, phone, name))
return rows return rows
def process_entities(self, tlo): def process_entities(self, tlo):

View File

@ -120,8 +120,8 @@ class SQLiteSession(MemorySession):
c.close() c.close()
self.save() self.save()
def clone(self): def clone(self, to_instance=None):
cloned = super().clone() cloned = super().clone(to_instance)
cloned.save_entities = self.save_entities cloned.save_entities = self.save_entities
return cloned return cloned
@ -180,9 +180,7 @@ class SQLiteSession(MemorySession):
# Data from sessions should be kept as properties # Data from sessions should be kept as properties
# not to fetch the database every time we need it # not to fetch the database every time we need it
def set_dc(self, dc_id, server_address, port): def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id super().set_dc(dc_id, server_address, port)
self._server_address = server_address
self._port = port
self._update_session_table() self._update_session_table()
# Fetch the auth_key corresponding to this data center # Fetch the auth_key corresponding to this data center
@ -287,15 +285,18 @@ class SQLiteSession(MemorySession):
'select id, hash from entities where phone=?', (phone,)) 'select id, hash from entities where phone=?', (phone,))
def get_entity_rows_by_username(self, username): def get_entity_rows_by_username(self, username):
self._fetchone_entity('select id, hash from entities where username=?', return self._fetchone_entity(
'select id, hash from entities where username=?',
(username,)) (username,))
def get_entity_rows_by_name(self, name): def get_entity_rows_by_name(self, name):
self._fetchone_entity('select id, hash from entities where name=?', return self._fetchone_entity(
'select id, hash from entities where name=?',
(name,)) (name,))
def get_entity_rows_by_id(self, id): def get_entity_rows_by_id(self, id):
self._fetchone_entity('select id, hash from entities where id=?', return self._fetchone_entity(
'select id, hash from entities where id=?',
(id,)) (id,))
# File processing # File processing