Add SQLAlchemy-based session

This commit is contained in:
Tulir Asokan 2018-03-02 13:22:30 +02:00
parent 118d9b10e8
commit 07c2fc50ec
3 changed files with 180 additions and 1 deletions

View File

@ -1,3 +1,4 @@
cryptg
pysocks
hachoir3
sqlalchemy

View File

@ -151,7 +151,8 @@ def main():
]),
install_requires=['pyaes', 'rsa'],
extras_require={
'cryptg': ['cryptg']
'cryptg': ['cryptg'],
'sqlalchemy': ['sqlalchemy']
}
)

View File

@ -0,0 +1,177 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, Blob, orm
import sqlalchemy as sql
from ..tl.types import InputPhoto, InputDocument
from .memory import MemorySession, _SentFileType
Base = declarative_base()
LATEST_VERSION = 1
class DBVersion(Base):
__tablename__ = "version"
version = Column(Integer, primary_key=True)
class DBSession(Base):
__tablename__ = "sessions"
session_id = Column(String, primary_key=True)
dc_id = Column(Integer, primary_key=True)
server_address = Column(String)
port = Column(Integer)
auth_key = Column(Blob)
class DBEntity(Base):
__tablename__ = "entities"
session_id = Column(String, primary_key=True)
id = Column(Integer, primary_key=True)
hash = Column(Integer, nullable=False)
username = Column(String)
phone = Column(Integer)
name = Column(String)
class DBSentFile(Base):
__tablename__ = "sent_files"
session_id = Column(String, primary_key=True)
md5_digest = Column(Blob, primary_key=True)
file_size = Column(Integer, primary_key=True)
type = Column(Integer, primary_key=True)
id = Column(Integer)
hash = Column(Integer)
class AlchemySessionContainer:
def __init__(self, database):
if not isinstance(database, sql.Engine):
database = sql.create_engine(database)
self.db_engine = database
db_factory = orm.sessionmaker(bind=self.db_engine)
self.db = orm.scoping.scoped_session(db_factory)
if not self.db_engine.dialect.has_table(self.db_engine,
DBVersion.__tablename__):
Base.metadata.create_all(bind=self.db_engine)
self.db.add(DBVersion(version=LATEST_VERSION))
self.db.commit()
else:
self.check_and_upgrade_database()
DBVersion.query = self.db.query_property()
DBSession.query = self.db.query_property()
DBEntity.query = self.db.query_property()
DBSentFile.query = self.db.query_property()
def check_and_upgrade_database(self):
row = DBVersion.query.get()
version = row.version if row else 1
if version == LATEST_VERSION:
return
DBVersion.query.delete()
# Implement table schema updates here and increase version
self.db.add(DBVersion(version=version))
self.db.commit()
def new_session(self, session_id):
return AlchemySession(self, session_id)
def list_sessions(self):
return
def save(self):
self.db.commit()
class AlchemySession(MemorySession):
def __init__(self, container, session_id):
super().__init__()
self.container = container
self.db = container.db
self.session_id = session_id
def clone(self, to_instance=None):
cloned = to_instance or self.__class__(self.container, self.session_id)
return super().clone(cloned)
def set_dc(self, dc_id, server_address, port):
super().set_dc(dc_id, server_address, port)
def _update_session_table(self):
self.db.query(DBSession).filter(
DBSession.session_id == self.session_id).delete()
new = DBSession(session_id=self.session_id, dc_id=self._dc_id,
server_address=self._server_address, port=self._port,
auth_key=self._auth_key.key if self._auth_key else b'')
self.db.merge(new)
def _db_query(self, dbclass, *args):
return self.db.query(dbclass).filter(
dbclass.session_id == self.session_id,
*args)
def save(self):
self.container.save()
def close(self):
# Nothing to do here, connection is managed by AlchemySessionContainer.
pass
def delete(self):
self._db_query(DBSession).delete()
self._db_query(DBEntity).delete()
self._db_query(DBSentFile).delete()
def _entity_values_to_row(self, id, hash, username, phone, name):
return DBEntity(session_id=self.session_id, id=id, hash=hash,
username=username, phone=phone, name=name)
def process_entities(self, tlo):
rows = self._entities_to_rows(tlo)
if not rows:
return
self.db.add_all(rows)
self.save()
def get_entity_rows_by_phone(self, key):
row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_username(self, key):
row = self._db_query(DBEntity, DBEntity.username == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_name(self, key):
row = self._db_query(DBEntity, DBEntity.name == key).one_or_none()
return row.id, row.hash if row else None
def get_entity_rows_by_id(self, key):
row = self._db_query(DBEntity, DBEntity.id == key).one_or_none()
return row.id, row.hash if row else None
def get_file(self, md5_digest, file_size, cls):
row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest,
DBSentFile.file_size == file_size,
DBSentFile.type == _SentFileType.from_type(
cls).value).one_or_none()
return row.id, row.hash if row else None
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
self.db.merge(
DBSentFile(session_id=self.session_id, md5_digest=md5_digest,
type=_SentFileType.from_type(type(instance)).value,
id=instance.id, hash=instance.access_hash))
self.save()