Properly close the sqlite3 connection (#560)

This commit is contained in:
Lonami Exo 2018-01-26 09:59:49 +01:00
parent 5c2dfc17a8
commit 43a3f40527
2 changed files with 30 additions and 13 deletions

View File

@ -7,7 +7,7 @@ import time
from base64 import b64decode from base64 import b64decode
from enum import Enum from enum import Enum
from os.path import isfile as file_exists from os.path import isfile as file_exists
from threading import Lock from threading import Lock, RLock
from . import utils from . import utils
from .crypto import AuthKey from .crypto import AuthKey
@ -89,7 +89,7 @@ class Session:
# Cross-thread safety # Cross-thread safety
self._seq_no_lock = Lock() self._seq_no_lock = Lock()
self._msg_id_lock = Lock() self._msg_id_lock = Lock()
self._db_lock = Lock() self._db_lock = RLock()
# These values will be saved # These values will be saved
self._dc_id = 0 self._dc_id = 0
@ -100,8 +100,8 @@ class Session:
# Migrating from .json -> SQL # Migrating from .json -> SQL
entities = self._check_migrate_json() entities = self._check_migrate_json()
self._conn = sqlite3.connect(self.filename, check_same_thread=False) self._conn = None
c = self._conn.cursor() c = self._cursor()
c.execute("select name from sqlite_master " c.execute("select name from sqlite_master "
"where type='table' and name='version'") "where type='table' and name='version'")
if c.fetchone(): if c.fetchone():
@ -186,7 +186,7 @@ class Session:
return [] # No entities return [] # No entities
def _upgrade_database(self, old): def _upgrade_database(self, old):
c = self._conn.cursor() c = self._cursor()
# old == 1 doesn't have the old sent_files so no need to drop # old == 1 doesn't have the old sent_files so no need to drop
if old == 2: if old == 2:
# Old cache from old sent_files lasts then a day anyway, drop # Old cache from old sent_files lasts then a day anyway, drop
@ -223,7 +223,7 @@ class Session:
self._update_session_table() self._update_session_table()
# Fetch the auth_key corresponding to this data center # Fetch the auth_key corresponding to this data center
c = self._conn.cursor() c = self._cursor()
c.execute('select auth_key from sessions') c.execute('select auth_key from sessions')
tuple_ = c.fetchone() tuple_ = c.fetchone()
if tuple_: if tuple_:
@ -251,7 +251,7 @@ class Session:
def _update_session_table(self): def _update_session_table(self):
with self._db_lock: with self._db_lock:
c = self._conn.cursor() c = self._cursor()
# While we can save multiple rows into the sessions table # While we can save multiple rows into the sessions table
# currently we only want to keep ONE as the tables don't # currently we only want to keep ONE as the tables don't
# tell us which auth_key's are usable and will work. Needs # tell us which auth_key's are usable and will work. Needs
@ -271,6 +271,22 @@ class Session:
with self._db_lock: with self._db_lock:
self._conn.commit() self._conn.commit()
def _cursor(self):
"""Asserts that the connection is open and returns a cursor"""
with self._db_lock:
if self._conn is None:
self._conn = sqlite3.connect(self.filename,
check_same_thread=False)
return self._conn.cursor()
def close(self):
"""Closes the connection unless we're working in-memory"""
if self.filename != ':memory:':
with self._db_lock:
if self._conn is not None:
self._conn.close()
self._conn = None
def delete(self): def delete(self):
"""Deletes the current session file""" """Deletes the current session file"""
if self.filename == ':memory:': if self.filename == ':memory:':
@ -385,10 +401,10 @@ class Session:
return return
with self._db_lock: with self._db_lock:
self._conn.executemany( self._cursor().executemany(
'insert or replace into entities values (?,?,?,?,?)', rows 'insert or replace into entities values (?,?,?,?,?)', rows
) )
self.save() self.save()
def get_input_entity(self, key): def get_input_entity(self, key):
"""Parses the given string, integer or TLObject key into a """Parses the given string, integer or TLObject key into a
@ -413,7 +429,7 @@ class Session:
if isinstance(key, TLObject): if isinstance(key, TLObject):
key = utils.get_peer_id(key) key = utils.get_peer_id(key)
c = self._conn.cursor() c = self._cursor()
if isinstance(key, str): if isinstance(key, str):
phone = utils.parse_phone(key) phone = utils.parse_phone(key)
if phone: if phone:
@ -444,7 +460,7 @@ class Session:
# File processing # File processing
def get_file(self, md5_digest, file_size, cls): def get_file(self, md5_digest, file_size, cls):
tuple_ = self._conn.execute( tuple_ = self._cursor().execute(
'select id, hash from sent_files ' 'select id, hash from sent_files '
'where md5_digest = ? and file_size = ? and type = ?', 'where md5_digest = ? and file_size = ? and type = ?',
(md5_digest, file_size, _SentFileType.from_type(cls).value) (md5_digest, file_size, _SentFileType.from_type(cls).value)
@ -458,10 +474,10 @@ class Session:
raise TypeError('Cannot cache %s instance' % type(instance)) raise TypeError('Cannot cache %s instance' % type(instance))
with self._db_lock: with self._db_lock:
self._conn.execute( self._cursor().execute(
'insert or replace into sent_files values (?,?,?,?,?)', ( 'insert or replace into sent_files values (?,?,?,?,?)', (
md5_digest, file_size, md5_digest, file_size,
_SentFileType.from_type(type(instance)).value, _SentFileType.from_type(type(instance)).value,
instance.id, instance.access_hash instance.id, instance.access_hash
)) ))
self.save() self.save()

View File

@ -253,6 +253,7 @@ class TelegramBareClient:
# TODO Shall we clear the _exported_sessions, or may be reused? # TODO Shall we clear the _exported_sessions, or may be reused?
self._first_request = True # On reconnect it will be first again self._first_request = True # On reconnect it will be first again
self.session.close()
def __del__(self): def __del__(self):
self.disconnect() self.disconnect()