Fix SQLAlchemy implementation

This commit is contained in:
Tulir Asokan 2018-03-02 20:14:11 +02:00
parent dc2229fdba
commit c1a8896faa
2 changed files with 114 additions and 84 deletions

View File

@ -1,3 +1,4 @@
from .abstract import Session from .abstract import Session
from .memory import MemorySession from .memory import MemorySession
from .sqlite import SQLiteSession from .sqlite import SQLiteSession
from .sqlalchemy import AlchemySessionContainer, AlchemySession

View File

@ -1,85 +1,95 @@
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, Blob, orm from sqlalchemy import Column, String, Integer, BLOB, orm
import sqlalchemy as sql import sqlalchemy as sql
from ..crypto import AuthKey
from ..tl.types import InputPhoto, InputDocument from ..tl.types import InputPhoto, InputDocument
from .memory import MemorySession, _SentFileType from .memory import MemorySession, _SentFileType
Base = declarative_base()
LATEST_VERSION = 1 LATEST_VERSION = 1
class DBVersion(Base):
__tablename__ = "version"
version = Column(Integer, primary_key=True)
class DBSession(Base):
__tablename__ = "sessions"
session_id = Column(String, primary_key=True)
dc_id = Column(Integer, primary_key=True)
server_address = Column(String)
port = Column(Integer)
auth_key = Column(Blob)
class DBEntity(Base):
__tablename__ = "entities"
session_id = Column(String, primary_key=True)
id = Column(Integer, primary_key=True)
hash = Column(Integer, nullable=False)
username = Column(String)
phone = Column(Integer)
name = Column(String)
class DBSentFile(Base):
__tablename__ = "sent_files"
session_id = Column(String, primary_key=True)
md5_digest = Column(Blob, primary_key=True)
file_size = Column(Integer, primary_key=True)
type = Column(Integer, primary_key=True)
id = Column(Integer)
hash = Column(Integer)
class AlchemySessionContainer: class AlchemySessionContainer:
def __init__(self, database): def __init__(self, engine=None, session=None, table_prefix="",
if isinstance(database, str): table_base=None, manage_tables=True):
database = sql.create_engine(database) if isinstance(engine, str):
engine = sql.create_engine(engine)
self.db_engine = database self.db_engine = engine
db_factory = orm.sessionmaker(bind=self.db_engine) if not session:
self.db = orm.scoping.scoped_session(db_factory) db_factory = orm.sessionmaker(bind=self.db_engine)
self.db = orm.scoping.scoped_session(db_factory)
if not self.db_engine.dialect.has_table(self.db_engine,
DBVersion.__tablename__):
Base.metadata.create_all(bind=self.db_engine)
self.db.add(DBVersion(version=LATEST_VERSION))
self.db.commit()
else: else:
self.check_and_upgrade_database() self.db = session
DBVersion.query = self.db.query_property() table_base = table_base or declarative_base()
DBSession.query = self.db.query_property() (self.Version, self.Session, self.Entity,
DBEntity.query = self.db.query_property() self.SentFile) = self.create_table_classes(self.db, table_prefix,
DBSentFile.query = self.db.query_property() table_base)
if manage_tables:
table_base.metadata.bind = self.db_engine
if not self.db_engine.dialect.has_table(self.db_engine,
self.Version.__tablename__):
table_base.metadata.create_all()
self.db.add(self.Version(version=LATEST_VERSION))
self.db.commit()
else:
self.check_and_upgrade_database()
@staticmethod
def create_table_classes(db, prefix, Base):
class Version(Base):
query = db.query_property()
__tablename__ = "{prefix}version".format(prefix=prefix)
version = Column(Integer, primary_key=True)
class Session(Base):
query = db.query_property()
__tablename__ = "{prefix}sessions".format(prefix=prefix)
session_id = Column(String, primary_key=True)
dc_id = Column(Integer, primary_key=True)
server_address = Column(String)
port = Column(Integer)
auth_key = Column(BLOB)
class Entity(Base):
query = db.query_property()
__tablename__ = "{prefix}entities".format(prefix=prefix)
session_id = Column(String, primary_key=True)
id = Column(Integer, primary_key=True)
hash = Column(Integer, nullable=False)
username = Column(String)
phone = Column(Integer)
name = Column(String)
class SentFile(Base):
query = db.query_property()
__tablename__ = "{prefix}sent_files".format(prefix=prefix)
session_id = Column(String, primary_key=True)
md5_digest = Column(BLOB, primary_key=True)
file_size = Column(Integer, primary_key=True)
type = Column(Integer, primary_key=True)
id = Column(Integer)
hash = Column(Integer)
return Version, Session, Entity, SentFile
def check_and_upgrade_database(self): def check_and_upgrade_database(self):
row = DBVersion.query.get() row = self.Version.query.all()
version = row.version if row else 1 version = row[0].version if row else 1
if version == LATEST_VERSION: if version == LATEST_VERSION:
return return
DBVersion.query.delete() self.Version.query.delete()
# Implement table schema updates here and increase version # Implement table schema updates here and increase version
self.db.add(DBVersion(version=version)) self.db.add(self.Version(version=version))
self.db.commit() self.db.commit()
def new_session(self, session_id): def new_session(self, session_id):
@ -97,7 +107,20 @@ class AlchemySession(MemorySession):
super().__init__() super().__init__()
self.container = container self.container = container
self.db = container.db self.db = container.db
self.Version, self.Session, self.Entity, self.SentFile = (
container.Version, container.Session, container.Entity,
container.SentFile)
self.session_id = session_id self.session_id = session_id
self.load_session()
def load_session(self):
sessions = self._db_query(self.Session).all()
session = sessions[0] if sessions else None
if session:
self._dc_id = session.dc_id
self._server_address = session.server_address
self._port = session.port
self._auth_key = AuthKey(data=session.auth_key)
def clone(self, to_instance=None): def clone(self, to_instance=None):
cloned = to_instance or self.__class__(self.container, self.session_id) cloned = to_instance or self.__class__(self.container, self.session_id)
@ -107,17 +130,18 @@ class AlchemySession(MemorySession):
super().set_dc(dc_id, server_address, port) super().set_dc(dc_id, server_address, port)
def _update_session_table(self): def _update_session_table(self):
self.db.query(DBSession).filter( self.Session.query.filter(
DBSession.session_id == self.session_id).delete() self.Session.session_id == self.session_id).delete()
new = DBSession(session_id=self.session_id, dc_id=self._dc_id, new = self.Session(session_id=self.session_id, dc_id=self._dc_id,
server_address=self._server_address, port=self._port, server_address=self._server_address,
auth_key=self._auth_key.key if self._auth_key else b'') port=self._port,
auth_key=(self._auth_key.key
if self._auth_key else b''))
self.db.merge(new) self.db.merge(new)
def _db_query(self, dbclass, *args): def _db_query(self, dbclass, *args):
return self.db.query(dbclass).filter( return dbclass.query.filter(dbclass.session_id == self.session_id,
dbclass.session_id == self.session_id, *args)
*args)
def save(self): def save(self):
self.container.save() self.container.save()
@ -127,42 +151,47 @@ class AlchemySession(MemorySession):
pass pass
def delete(self): def delete(self):
self._db_query(DBSession).delete() self._db_query(self.Session).delete()
self._db_query(DBEntity).delete() self._db_query(self.Entity).delete()
self._db_query(DBSentFile).delete() self._db_query(self.SentFile).delete()
def _entity_values_to_row(self, id, hash, username, phone, name): def _entity_values_to_row(self, id, hash, username, phone, name):
return DBEntity(session_id=self.session_id, id=id, hash=hash, return self.Entity(session_id=self.session_id, id=id, hash=hash,
username=username, phone=phone, name=name) username=username, phone=phone, name=name)
def process_entities(self, tlo): def process_entities(self, tlo):
rows = self._entities_to_rows(tlo) rows = self._entities_to_rows(tlo)
if not rows: if not rows:
return return
self.db.add_all(rows) for row in rows:
self.db.merge(row)
self.save() self.save()
def get_entity_rows_by_phone(self, key): def get_entity_rows_by_phone(self, key):
row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none() row = self._db_query(self.Entity,
self.Entity.phone == key).one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
def get_entity_rows_by_username(self, key): def get_entity_rows_by_username(self, key):
row = self._db_query(DBEntity, DBEntity.username == key).one_or_none() row = self._db_query(self.Entity,
self.Entity.username == key).one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
def get_entity_rows_by_name(self, key): def get_entity_rows_by_name(self, key):
row = self._db_query(DBEntity, DBEntity.name == key).one_or_none() row = self._db_query(self.Entity,
self.Entity.name == key).one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
def get_entity_rows_by_id(self, key): def get_entity_rows_by_id(self, key):
row = self._db_query(DBEntity, DBEntity.id == key).one_or_none() row = self._db_query(self.Entity, self.Entity.id == key).one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
def get_file(self, md5_digest, file_size, cls): def get_file(self, md5_digest, file_size, cls):
row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest, row = self._db_query(self.SentFile,
DBSentFile.file_size == file_size, self.SentFile.md5_digest == md5_digest,
DBSentFile.type == _SentFileType.from_type( self.SentFile.file_size == file_size,
self.SentFile.type == _SentFileType.from_type(
cls).value).one_or_none() cls).value).one_or_none()
return row.id, row.hash if row else None return row.id, row.hash if row else None
@ -171,7 +200,7 @@ class AlchemySession(MemorySession):
raise TypeError('Cannot cache %s instance' % type(instance)) raise TypeError('Cannot cache %s instance' % type(instance))
self.db.merge( self.db.merge(
DBSentFile(session_id=self.session_id, md5_digest=md5_digest, self.SentFile(session_id=self.session_id, md5_digest=md5_digest,
type=_SentFileType.from_type(type(instance)).value, type=_SentFileType.from_type(type(instance)).value,
id=instance.id, hash=instance.access_hash)) id=instance.id, hash=instance.access_hash))
self.save() self.save()