diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index c15ab1b2..cc4bd879 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -179,14 +179,11 @@ class SQLiteSession(MemorySession): self._update_session_table() # Fetch the auth_key corresponding to this data center - c = self._cursor() - c.execute('select auth_key from sessions') - tuple_ = c.fetchone() - if tuple_ and tuple_[0]: - self._auth_key = AuthKey(data=tuple_[0]) + row = self._execute('select auth_key from sessions') + if row and row[0]: + self._auth_key = AuthKey(data=row[0]) else: self._auth_key = None - c.close() @MemorySession.auth_key.setter def auth_key(self, value): @@ -210,21 +207,17 @@ class SQLiteSession(MemorySession): c.close() def get_update_state(self, entity_id): - c = self._cursor() - row = c.execute('select pts, qts, date, seq from update_state ' - 'where id = ?', (entity_id,)).fetchone() - c.close() + row = self._execute('select pts, qts, date, seq from update_state ' + 'where id = ?', entity_id) if row: pts, qts, date, seq = row date = datetime.datetime.utcfromtimestamp(date) return types.updates.State(pts, qts, date, seq, unread_count=0) def set_update_state(self, entity_id, state): - c = self._cursor() - c.execute('insert or replace into update_state values (?,?,?,?,?)', - (entity_id, state.pts, state.qts, - state.date.timestamp(), state.seq)) - c.close() + self._execute('insert or replace into update_state values (?,?,?,?,?)', + entity_id, state.pts, state.qts, + state.date.timestamp(), state.seq) def save(self): """Saves the current session object as session_user_id.session""" @@ -239,6 +232,17 @@ class SQLiteSession(MemorySession): check_same_thread=False) return self._conn.cursor() + def _execute(self, stmt, *values): + """ + Gets a cursor, executes `stmt` and closes the cursor, + fetching one row afterwards and returning its result. + """ + c = self._cursor() + try: + return c.execute(stmt, values).fetchone() + finally: + c.close() + def close(self): """Closes the connection unless we're working in-memory""" if self.filename != ':memory:': @@ -281,67 +285,55 @@ class SQLiteSession(MemorySession): return c = self._cursor() - c.executemany( - 'insert or replace into entities values (?,?,?,?,?)', rows - ) - c.close() - - def _fetchone_entity(self, query, args): - c = self._cursor() - c.execute(query, args) - t = c.fetchone() - c.close() - return t + try: + c.executemany( + 'insert or replace into entities values (?,?,?,?,?)', rows) + finally: + c.close() def get_entity_rows_by_phone(self, phone): - return self._fetchone_entity( - 'select id, hash from entities where phone=?', (phone,)) + return self._execute( + 'select id, hash from entities where phone = ?', phone) def get_entity_rows_by_username(self, username): - return self._fetchone_entity( - 'select id, hash from entities where username=?', (username,)) + return self._execute( + 'select id, hash from entities where username = ?', username) def get_entity_rows_by_name(self, name): - return self._fetchone_entity( - 'select id, hash from entities where name=?', (name,)) + return self._execute( + 'select id, hash from entities where name = ?', name) def get_entity_rows_by_id(self, id, exact=True): if exact: - return self._fetchone_entity( - 'select id, hash from entities where id=?', (id,)) + return self._execute( + 'select id, hash from entities where id = ?', id) else: - ids = ( + return self._execute( + 'select id, hash from entities where id in (?,?,?)', utils.get_peer_id(PeerUser(id)), utils.get_peer_id(PeerChat(id)), utils.get_peer_id(PeerChannel(id)) ) - return self._fetchone_entity( - 'select id, hash from entities where id in (?,?,?)', ids - ) # File processing def get_file(self, md5_digest, file_size, cls): - c = self._cursor() - tuple_ = c.execute( + row = self._execute( 'select id, hash from sent_files ' 'where md5_digest = ? and file_size = ? and type = ?', - (md5_digest, file_size, _SentFileType.from_type(cls).value) - ).fetchone() - c.close() - if tuple_: + md5_digest, file_size, _SentFileType.from_type(cls).value + ) + if row: # Both allowed classes have (id, access_hash) as parameters - return cls(tuple_[0], tuple_[1]) + return cls(row[0], row[1]) def cache_file(self, md5_digest, file_size, instance): if not isinstance(instance, (InputDocument, InputPhoto)): raise TypeError('Cannot cache %s instance' % type(instance)) - c = self._cursor() - c.execute( - 'insert or replace into sent_files values (?,?,?,?,?)', ( - md5_digest, file_size, - _SentFileType.from_type(type(instance)).value, - instance.id, instance.access_hash - )) - c.close() + self._execute( + 'insert or replace into sent_files values (?,?,?,?,?)', + md5_digest, file_size, + _SentFileType.from_type(type(instance)).value, + instance.id, instance.access_hash + )