Remove SQLAlchemy session

This commit is contained in:
Lonami Exo 2018-03-17 17:38:16 +01:00
parent 6f820cce97
commit 32fd64d655
4 changed files with 1 additions and 240 deletions

View File

@ -1,4 +1,3 @@
cryptg cryptg
pysocks pysocks
hachoir3 hachoir3
sqlalchemy

View File

@ -156,8 +156,7 @@ def main():
install_requires=['pyaes', 'rsa', install_requires=['pyaes', 'rsa',
'typing' if version_info < (3, 5) else ""], 'typing' if version_info < (3, 5) else ""],
extras_require={ extras_require={
'cryptg': ['cryptg'], 'cryptg': ['cryptg']
'sqlalchemy': ['sqlalchemy']
} }
) )

View File

@ -1,4 +1,3 @@
from .abstract import Session from .abstract import Session
from .memory import MemorySession from .memory import MemorySession
from .sqlite import SQLiteSession from .sqlite import SQLiteSession
from .sqlalchemy import AlchemySessionContainer, AlchemySession

View File

@ -1,236 +0,0 @@
try:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, Integer, LargeBinary, orm
import sqlalchemy as sql
except ImportError:
sql = None
from .memory import MemorySession, _SentFileType
from .. import utils
from ..crypto import AuthKey
from ..tl.types import (
InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel
)
LATEST_VERSION = 1
class AlchemySessionContainer:
def __init__(self, engine=None, session=None, table_prefix='',
table_base=None, manage_tables=True):
if not sql:
raise ImportError('SQLAlchemy not imported')
if isinstance(engine, str):
engine = sql.create_engine(engine)
self.db_engine = engine
if not session:
db_factory = orm.sessionmaker(bind=self.db_engine)
self.db = orm.scoping.scoped_session(db_factory)
else:
self.db = session
table_base = table_base or declarative_base()
(self.Version, self.Session, self.Entity,
self.SentFile) = self.create_table_classes(self.db, table_prefix,
table_base)
if manage_tables:
table_base.metadata.bind = self.db_engine
if not self.db_engine.dialect.has_table(self.db_engine,
self.Version.__tablename__):
table_base.metadata.create_all()
self.db.add(self.Version(version=LATEST_VERSION))
self.db.commit()
else:
self.check_and_upgrade_database()
@staticmethod
def create_table_classes(db, prefix, Base):
class Version(Base):
query = db.query_property()
__tablename__ = '{prefix}version'.format(prefix=prefix)
version = Column(Integer, primary_key=True)
class Session(Base):
query = db.query_property()
__tablename__ = '{prefix}sessions'.format(prefix=prefix)
session_id = Column(String, primary_key=True)
dc_id = Column(Integer, primary_key=True)
server_address = Column(String)
port = Column(Integer)
auth_key = Column(LargeBinary)
class Entity(Base):
query = db.query_property()
__tablename__ = '{prefix}entities'.format(prefix=prefix)
session_id = Column(String, primary_key=True)
id = Column(Integer, primary_key=True)
hash = Column(Integer, nullable=False)
username = Column(String)
phone = Column(Integer)
name = Column(String)
class SentFile(Base):
query = db.query_property()
__tablename__ = '{prefix}sent_files'.format(prefix=prefix)
session_id = Column(String, primary_key=True)
md5_digest = Column(LargeBinary, primary_key=True)
file_size = Column(Integer, primary_key=True)
type = Column(Integer, primary_key=True)
id = Column(Integer)
hash = Column(Integer)
return Version, Session, Entity, SentFile
def check_and_upgrade_database(self):
row = self.Version.query.all()
version = row[0].version if row else 1
if version == LATEST_VERSION:
return
self.Version.query.delete()
# Implement table schema updates here and increase version
self.db.add(self.Version(version=version))
self.db.commit()
def new_session(self, session_id):
return AlchemySession(self, session_id)
def list_sessions(self):
return
def save(self):
self.db.commit()
class AlchemySession(MemorySession):
def __init__(self, container, session_id):
super().__init__()
self.container = container
self.db = container.db
self.Version, self.Session, self.Entity, self.SentFile = (
container.Version, container.Session, container.Entity,
container.SentFile)
self.session_id = session_id
self._load_session()
def _load_session(self):
sessions = self._db_query(self.Session).all()
session = sessions[0] if sessions else None
if session:
self._dc_id = session.dc_id
self._server_address = session.server_address
self._port = session.port
self._auth_key = AuthKey(data=session.auth_key)
def clone(self, to_instance=None):
return super().clone(MemorySession())
def set_dc(self, dc_id, server_address, port):
super().set_dc(dc_id, server_address, port)
self._update_session_table()
sessions = self._db_query(self.Session).all()
session = sessions[0] if sessions else None
if session and session.auth_key:
self._auth_key = AuthKey(data=session.auth_key)
else:
self._auth_key = None
@MemorySession.auth_key.setter
def auth_key(self, value):
self._auth_key = value
self._update_session_table()
def _update_session_table(self):
self.Session.query.filter(
self.Session.session_id == self.session_id).delete()
new = self.Session(session_id=self.session_id, dc_id=self._dc_id,
server_address=self._server_address,
port=self._port,
auth_key=(self._auth_key.key
if self._auth_key else b''))
self.db.merge(new)
def _db_query(self, dbclass, *args):
return dbclass.query.filter(
dbclass.session_id == self.session_id, *args
)
def save(self):
self.container.save()
def close(self):
# Nothing to do here, connection is managed by AlchemySessionContainer.
pass
def delete(self):
self._db_query(self.Session).delete()
self._db_query(self.Entity).delete()
self._db_query(self.SentFile).delete()
def _entity_values_to_row(self, id, hash, username, phone, name):
return self.Entity(session_id=self.session_id, id=id, hash=hash,
username=username, phone=phone, name=name)
def process_entities(self, tlo):
rows = self._entities_to_rows(tlo)
if not rows:
return
for row in rows:
self.db.merge(row)
self.save()
def get_entity_rows_by_phone(self, key):
row = self._db_query(self.Entity,
self.Entity.phone == key).one_or_none()
return (row.id, row.hash) if row else None
def get_entity_rows_by_username(self, key):
row = self._db_query(self.Entity,
self.Entity.username == key).one_or_none()
return (row.id, row.hash) if row else None
def get_entity_rows_by_name(self, key):
row = self._db_query(self.Entity,
self.Entity.name == key).one_or_none()
return (row.id, row.hash) if row else None
def get_entity_rows_by_id(self, key, exact=True):
if exact:
query = self._db_query(self.Entity, self.Entity.id == key)
else:
ids = (
utils.get_peer_id(PeerUser(key)),
utils.get_peer_id(PeerChat(key)),
utils.get_peer_id(PeerChannel(key))
)
query = self._db_query(self.Entity, self.Entity.id in ids)
row = query.one_or_none()
return (row.id, row.hash) if row else None
def get_file(self, md5_digest, file_size, cls):
row = self._db_query(self.SentFile,
self.SentFile.md5_digest == md5_digest,
self.SentFile.file_size == file_size,
self.SentFile.type == _SentFileType.from_type(
cls).value).one_or_none()
return (row.id, row.hash) if row else None
def cache_file(self, md5_digest, file_size, instance):
if not isinstance(instance, (InputDocument, InputPhoto)):
raise TypeError('Cannot cache %s instance' % type(instance))
self.db.merge(
self.SentFile(session_id=self.session_id, md5_digest=md5_digest,
type=_SentFileType.from_type(type(instance)).value,
id=instance.id, hash=instance.access_hash))
self.save()