Fix SQLAlchemy implementation

This commit is contained in:
Tulir Asokan 2018-03-02 20:14:11 +02:00
parent dc2229fdba
commit c1a8896faa
2 changed files with 114 additions and 84 deletions

View File

@ -1,3 +1,4 @@
from .abstract import Session
from .memory import MemorySession
from .sqlite import SQLiteSession
from .sqlalchemy import AlchemySessionContainer, AlchemySession

View File

@ -1,32 +1,63 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, Blob, orm
from sqlalchemy import Column, String, Integer, BLOB, orm
import sqlalchemy as sql
from ..crypto import AuthKey
from ..tl.types import InputPhoto, InputDocument
from .memory import MemorySession, _SentFileType
Base = declarative_base()
LATEST_VERSION = 1
class DBVersion(Base):
__tablename__ = "version"
class AlchemySessionContainer:
def __init__(self, engine=None, session=None, table_prefix="",
table_base=None, manage_tables=True):
if isinstance(engine, str):
engine = sql.create_engine(engine)
self.db_engine = engine
if not session:
db_factory = orm.sessionmaker(bind=self.db_engine)
self.db = orm.scoping.scoped_session(db_factory)
else:
self.db = session
table_base = table_base or declarative_base()
(self.Version, self.Session, self.Entity,
self.SentFile) = self.create_table_classes(self.db, table_prefix,
table_base)
if manage_tables:
table_base.metadata.bind = self.db_engine
if not self.db_engine.dialect.has_table(self.db_engine,
self.Version.__tablename__):
table_base.metadata.create_all()
self.db.add(self.Version(version=LATEST_VERSION))
self.db.commit()
else:
self.check_and_upgrade_database()
@staticmethod
def create_table_classes(db, prefix, Base):
class Version(Base):
query = db.query_property()
__tablename__ = "{prefix}version".format(prefix=prefix)
version = Column(Integer, primary_key=True)
class DBSession(Base):
__tablename__ = "sessions"
class Session(Base):
query = db.query_property()
__tablename__ = "{prefix}sessions".format(prefix=prefix)
session_id = Column(String, primary_key=True)
dc_id = Column(Integer, primary_key=True)
server_address = Column(String)
port = Column(Integer)
auth_key = Column(Blob)
auth_key = Column(BLOB)
class DBEntity(Base):
__tablename__ = "entities"
class Entity(Base):
query = db.query_property()
__tablename__ = "{prefix}entities".format(prefix=prefix)
session_id = Column(String, primary_key=True)
id = Column(Integer, primary_key=True)
@ -35,51 +66,30 @@ class DBEntity(Base):
phone = Column(Integer)
name = Column(String)
class DBSentFile(Base):
__tablename__ = "sent_files"
class SentFile(Base):
query = db.query_property()
__tablename__ = "{prefix}sent_files".format(prefix=prefix)
session_id = Column(String, primary_key=True)
md5_digest = Column(Blob, primary_key=True)
md5_digest = Column(BLOB, primary_key=True)
file_size = Column(Integer, primary_key=True)
type = Column(Integer, primary_key=True)
id = Column(Integer)
hash = Column(Integer)
class AlchemySessionContainer:
def __init__(self, database):
if isinstance(database, str):
database = sql.create_engine(database)
self.db_engine = database
db_factory = orm.sessionmaker(bind=self.db_engine)
self.db = orm.scoping.scoped_session(db_factory)
if not self.db_engine.dialect.has_table(self.db_engine,
DBVersion.__tablename__):
Base.metadata.create_all(bind=self.db_engine)
self.db.add(DBVersion(version=LATEST_VERSION))
self.db.commit()
else:
self.check_and_upgrade_database()
DBVersion.query = self.db.query_property()
DBSession.query = self.db.query_property()
DBEntity.query = self.db.query_property()
DBSentFile.query = self.db.query_property()
return Version, Session, Entity, SentFile
def check_and_upgrade_database(self):
row = DBVersion.query.get()
version = row.version if row else 1
row = self.Version.query.all()
version = row[0].version if row else 1
if version == LATEST_VERSION:
return
DBVersion.query.delete()
self.Version.query.delete()
# Implement table schema updates here and increase version
self.db.add(DBVersion(version=version))
self.db.add(self.Version(version=version))
self.db.commit()
def new_session(self, session_id):
@ -97,7 +107,20 @@ class AlchemySession(MemorySession):
super().__init__()
self.container = container
self.db = container.db
self.Version, self.Session, self.Entity, self.SentFile = (
container.Version, container.Session, container.Entity,
container.SentFile)
self.session_id = session_id
self.load_session()
def load_session(self):
sessions = self._db_query(self.Session).all()
session = sessions[0] if sessions else None
if session:
self._dc_id = session.dc_id
self._server_address = session.server_address
self._port = session.port
self._auth_key = AuthKey(data=session.auth_key)
def clone(self, to_instance=None):
cloned = to_instance or self.__class__(self.container, self.session_id)
@ -107,16 +130,17 @@ class AlchemySession(MemorySession):
super().set_dc(dc_id, server_address, port)
def _update_session_table(self):
self.db.query(DBSession).filter(
DBSession.session_id == self.session_id).delete()
new = DBSession(session_id=self.session_id, dc_id=self._dc_id,
server_address=self._server_address, port=self._port,
auth_key=self._auth_key.key if self._auth_key else b'')
self.Session.query.filter(
self.Session.session_id == self.session_id).delete()
new = self.Session(session_id=self.session_id, dc_id=self._dc_id,
server_address=self._server_address,
port=self._port,
auth_key=(self._auth_key.key
if self._auth_key else b''))
self.db.merge(new)
def _db_query(self, dbclass, *args):
return self.db.query(dbclass).filter(
dbclass.session_id == self.session_id,
return dbclass.query.filter(dbclass.session_id == self.session_id,
*args)
def save(self):
@ -127,12 +151,12 @@ class AlchemySession(MemorySession):
pass
def delete(self):
self._db_query(DBSession).delete()
self._db_query(DBEntity).delete()
self._db_query(DBSentFile).delete()
self._db_query(self.Session).delete()
self._db_query(self.Entity).delete()
self._db_query(self.SentFile).delete()
def _entity_values_to_row(self, id, hash, username, phone, name):
return DBEntity(session_id=self.session_id, id=id, hash=hash,
return self.Entity(session_id=self.session_id, id=id, hash=hash,
username=username, phone=phone, name=name)
def process_entities(self, tlo):
@ -140,29 +164,34 @@ class AlchemySession(MemorySession):
if not rows:
return
self.db.add_all(rows)
for row in rows:
self.db.merge(row)
self.save()
def get_entity_rows_by_phone(self, key):
row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none()
row = self._db_query(self.Entity,
self.Entity.phone == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_username(self, key):
row = self._db_query(DBEntity, DBEntity.username == key).one_or_none()
row = self._db_query(self.Entity,
self.Entity.username == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_name(self, key):
row = self._db_query(DBEntity, DBEntity.name == key).one_or_none()
row = self._db_query(self.Entity,
self.Entity.name == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_id(self, key):
row = self._db_query(DBEntity, DBEntity.id == key).one_or_none()
row = self._db_query(self.Entity, self.Entity.id == key).one_or_none()
return row.id, row.hash if row else None
def get_file(self, md5_digest, file_size, cls):
row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest,
DBSentFile.file_size == file_size,
DBSentFile.type == _SentFileType.from_type(
row = self._db_query(self.SentFile,
self.SentFile.md5_digest == md5_digest,
self.SentFile.file_size == file_size,
self.SentFile.type == _SentFileType.from_type(
cls).value).one_or_none()
return row.id, row.hash if row else None
@ -171,7 +200,7 @@ class AlchemySession(MemorySession):
raise TypeError('Cannot cache %s instance' % type(instance))
self.db.merge(
DBSentFile(session_id=self.session_id, md5_digest=md5_digest,
self.SentFile(session_id=self.session_id, md5_digest=md5_digest,
type=_SentFileType.from_type(type(instance)).value,
id=instance.id, hash=instance.access_hash))
self.save()