Reuse more code using sqlite's cursor

This commit is contained in:
Lonami Exo 2018-06-25 20:11:48 +02:00
parent 313bead615
commit 551b0044ce

View File

@ -179,14 +179,11 @@ class SQLiteSession(MemorySession):
self._update_session_table() self._update_session_table()
# Fetch the auth_key corresponding to this data center # Fetch the auth_key corresponding to this data center
c = self._cursor() row = self._execute('select auth_key from sessions')
c.execute('select auth_key from sessions') if row and row[0]:
tuple_ = c.fetchone() self._auth_key = AuthKey(data=row[0])
if tuple_ and tuple_[0]:
self._auth_key = AuthKey(data=tuple_[0])
else: else:
self._auth_key = None self._auth_key = None
c.close()
@MemorySession.auth_key.setter @MemorySession.auth_key.setter
def auth_key(self, value): def auth_key(self, value):
@ -210,21 +207,17 @@ class SQLiteSession(MemorySession):
c.close() c.close()
def get_update_state(self, entity_id): def get_update_state(self, entity_id):
c = self._cursor() row = self._execute('select pts, qts, date, seq from update_state '
row = c.execute('select pts, qts, date, seq from update_state ' 'where id = ?', entity_id)
'where id = ?', (entity_id,)).fetchone()
c.close()
if row: if row:
pts, qts, date, seq = row pts, qts, date, seq = row
date = datetime.datetime.utcfromtimestamp(date) date = datetime.datetime.utcfromtimestamp(date)
return types.updates.State(pts, qts, date, seq, unread_count=0) return types.updates.State(pts, qts, date, seq, unread_count=0)
def set_update_state(self, entity_id, state): def set_update_state(self, entity_id, state):
c = self._cursor() self._execute('insert or replace into update_state values (?,?,?,?,?)',
c.execute('insert or replace into update_state values (?,?,?,?,?)', entity_id, state.pts, state.qts,
(entity_id, state.pts, state.qts, state.date.timestamp(), state.seq)
state.date.timestamp(), state.seq))
c.close()
def save(self): def save(self):
"""Saves the current session object as session_user_id.session""" """Saves the current session object as session_user_id.session"""
@ -239,6 +232,17 @@ class SQLiteSession(MemorySession):
check_same_thread=False) check_same_thread=False)
return self._conn.cursor() return self._conn.cursor()
def _execute(self, stmt, *values):
"""
Gets a cursor, executes `stmt` and closes the cursor,
fetching one row afterwards and returning its result.
"""
c = self._cursor()
try:
return c.execute(stmt, values).fetchone()
finally:
c.close()
def close(self): def close(self):
"""Closes the connection unless we're working in-memory""" """Closes the connection unless we're working in-memory"""
if self.filename != ':memory:': if self.filename != ':memory:':
@ -281,67 +285,55 @@ class SQLiteSession(MemorySession):
return return
c = self._cursor() c = self._cursor()
try:
c.executemany( c.executemany(
'insert or replace into entities values (?,?,?,?,?)', rows 'insert or replace into entities values (?,?,?,?,?)', rows)
) finally:
c.close() c.close()
def _fetchone_entity(self, query, args):
c = self._cursor()
c.execute(query, args)
t = c.fetchone()
c.close()
return t
def get_entity_rows_by_phone(self, phone): def get_entity_rows_by_phone(self, phone):
return self._fetchone_entity( return self._execute(
'select id, hash from entities where phone=?', (phone,)) 'select id, hash from entities where phone = ?', phone)
def get_entity_rows_by_username(self, username): def get_entity_rows_by_username(self, username):
return self._fetchone_entity( return self._execute(
'select id, hash from entities where username=?', (username,)) 'select id, hash from entities where username = ?', username)
def get_entity_rows_by_name(self, name): def get_entity_rows_by_name(self, name):
return self._fetchone_entity( return self._execute(
'select id, hash from entities where name=?', (name,)) 'select id, hash from entities where name = ?', name)
def get_entity_rows_by_id(self, id, exact=True): def get_entity_rows_by_id(self, id, exact=True):
if exact: if exact:
return self._fetchone_entity( return self._execute(
'select id, hash from entities where id=?', (id,)) 'select id, hash from entities where id = ?', id)
else: else:
ids = ( return self._execute(
'select id, hash from entities where id in (?,?,?)',
utils.get_peer_id(PeerUser(id)), utils.get_peer_id(PeerUser(id)),
utils.get_peer_id(PeerChat(id)), utils.get_peer_id(PeerChat(id)),
utils.get_peer_id(PeerChannel(id)) utils.get_peer_id(PeerChannel(id))
) )
return self._fetchone_entity(
'select id, hash from entities where id in (?,?,?)', ids
)
# File processing # File processing
def get_file(self, md5_digest, file_size, cls): def get_file(self, md5_digest, file_size, cls):
c = self._cursor() row = self._execute(
tuple_ = c.execute(
'select id, hash from sent_files ' 'select id, hash from sent_files '
'where md5_digest = ? and file_size = ? and type = ?', 'where md5_digest = ? and file_size = ? and type = ?',
(md5_digest, file_size, _SentFileType.from_type(cls).value) md5_digest, file_size, _SentFileType.from_type(cls).value
).fetchone() )
c.close() if row:
if tuple_:
# Both allowed classes have (id, access_hash) as parameters # Both allowed classes have (id, access_hash) as parameters
return cls(tuple_[0], tuple_[1]) return cls(row[0], row[1])
def cache_file(self, md5_digest, file_size, instance): def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)): if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance)) raise TypeError('Cannot cache %s instance' % type(instance))
c = self._cursor() self._execute(
c.execute( 'insert or replace into sent_files values (?,?,?,?,?)',
'insert or replace into sent_files values (?,?,?,?,?)', (
md5_digest, file_size, md5_digest, file_size,
_SentFileType.from_type(type(instance)).value, _SentFileType.from_type(type(instance)).value,
instance.id, instance.access_hash instance.id, instance.access_hash
)) )
c.close()