mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2025-02-17 03:51:05 +03:00
Fix SQLAlchemy implementation
This commit is contained in:
parent
dc2229fdba
commit
c1a8896faa
|
@ -1,3 +1,4 @@
|
||||||
from .abstract import Session
|
from .abstract import Session
|
||||||
from .memory import MemorySession
|
from .memory import MemorySession
|
||||||
from .sqlite import SQLiteSession
|
from .sqlite import SQLiteSession
|
||||||
|
from .sqlalchemy import AlchemySessionContainer, AlchemySession
|
||||||
|
|
|
@ -1,85 +1,95 @@
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy import Column, String, Integer, Blob, orm
|
from sqlalchemy import Column, String, Integer, BLOB, orm
|
||||||
import sqlalchemy as sql
|
import sqlalchemy as sql
|
||||||
|
|
||||||
|
from ..crypto import AuthKey
|
||||||
from ..tl.types import InputPhoto, InputDocument
|
from ..tl.types import InputPhoto, InputDocument
|
||||||
|
|
||||||
from .memory import MemorySession, _SentFileType
|
from .memory import MemorySession, _SentFileType
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
LATEST_VERSION = 1
|
LATEST_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
class DBVersion(Base):
|
|
||||||
__tablename__ = "version"
|
|
||||||
version = Column(Integer, primary_key=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DBSession(Base):
|
|
||||||
__tablename__ = "sessions"
|
|
||||||
|
|
||||||
session_id = Column(String, primary_key=True)
|
|
||||||
dc_id = Column(Integer, primary_key=True)
|
|
||||||
server_address = Column(String)
|
|
||||||
port = Column(Integer)
|
|
||||||
auth_key = Column(Blob)
|
|
||||||
|
|
||||||
|
|
||||||
class DBEntity(Base):
|
|
||||||
__tablename__ = "entities"
|
|
||||||
|
|
||||||
session_id = Column(String, primary_key=True)
|
|
||||||
id = Column(Integer, primary_key=True)
|
|
||||||
hash = Column(Integer, nullable=False)
|
|
||||||
username = Column(String)
|
|
||||||
phone = Column(Integer)
|
|
||||||
name = Column(String)
|
|
||||||
|
|
||||||
|
|
||||||
class DBSentFile(Base):
|
|
||||||
__tablename__ = "sent_files"
|
|
||||||
|
|
||||||
session_id = Column(String, primary_key=True)
|
|
||||||
md5_digest = Column(Blob, primary_key=True)
|
|
||||||
file_size = Column(Integer, primary_key=True)
|
|
||||||
type = Column(Integer, primary_key=True)
|
|
||||||
id = Column(Integer)
|
|
||||||
hash = Column(Integer)
|
|
||||||
|
|
||||||
|
|
||||||
class AlchemySessionContainer:
|
class AlchemySessionContainer:
|
||||||
def __init__(self, database):
|
def __init__(self, engine=None, session=None, table_prefix="",
|
||||||
if isinstance(database, str):
|
table_base=None, manage_tables=True):
|
||||||
database = sql.create_engine(database)
|
if isinstance(engine, str):
|
||||||
|
engine = sql.create_engine(engine)
|
||||||
|
|
||||||
self.db_engine = database
|
self.db_engine = engine
|
||||||
db_factory = orm.sessionmaker(bind=self.db_engine)
|
if not session:
|
||||||
self.db = orm.scoping.scoped_session(db_factory)
|
db_factory = orm.sessionmaker(bind=self.db_engine)
|
||||||
|
self.db = orm.scoping.scoped_session(db_factory)
|
||||||
if not self.db_engine.dialect.has_table(self.db_engine,
|
|
||||||
DBVersion.__tablename__):
|
|
||||||
Base.metadata.create_all(bind=self.db_engine)
|
|
||||||
self.db.add(DBVersion(version=LATEST_VERSION))
|
|
||||||
self.db.commit()
|
|
||||||
else:
|
else:
|
||||||
self.check_and_upgrade_database()
|
self.db = session
|
||||||
|
|
||||||
DBVersion.query = self.db.query_property()
|
table_base = table_base or declarative_base()
|
||||||
DBSession.query = self.db.query_property()
|
(self.Version, self.Session, self.Entity,
|
||||||
DBEntity.query = self.db.query_property()
|
self.SentFile) = self.create_table_classes(self.db, table_prefix,
|
||||||
DBSentFile.query = self.db.query_property()
|
table_base)
|
||||||
|
|
||||||
|
if manage_tables:
|
||||||
|
table_base.metadata.bind = self.db_engine
|
||||||
|
if not self.db_engine.dialect.has_table(self.db_engine,
|
||||||
|
self.Version.__tablename__):
|
||||||
|
table_base.metadata.create_all()
|
||||||
|
self.db.add(self.Version(version=LATEST_VERSION))
|
||||||
|
self.db.commit()
|
||||||
|
else:
|
||||||
|
self.check_and_upgrade_database()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_table_classes(db, prefix, Base):
|
||||||
|
class Version(Base):
|
||||||
|
query = db.query_property()
|
||||||
|
__tablename__ = "{prefix}version".format(prefix=prefix)
|
||||||
|
version = Column(Integer, primary_key=True)
|
||||||
|
|
||||||
|
class Session(Base):
|
||||||
|
query = db.query_property()
|
||||||
|
__tablename__ = "{prefix}sessions".format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String, primary_key=True)
|
||||||
|
dc_id = Column(Integer, primary_key=True)
|
||||||
|
server_address = Column(String)
|
||||||
|
port = Column(Integer)
|
||||||
|
auth_key = Column(BLOB)
|
||||||
|
|
||||||
|
class Entity(Base):
|
||||||
|
query = db.query_property()
|
||||||
|
__tablename__ = "{prefix}entities".format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String, primary_key=True)
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
hash = Column(Integer, nullable=False)
|
||||||
|
username = Column(String)
|
||||||
|
phone = Column(Integer)
|
||||||
|
name = Column(String)
|
||||||
|
|
||||||
|
class SentFile(Base):
|
||||||
|
query = db.query_property()
|
||||||
|
__tablename__ = "{prefix}sent_files".format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String, primary_key=True)
|
||||||
|
md5_digest = Column(BLOB, primary_key=True)
|
||||||
|
file_size = Column(Integer, primary_key=True)
|
||||||
|
type = Column(Integer, primary_key=True)
|
||||||
|
id = Column(Integer)
|
||||||
|
hash = Column(Integer)
|
||||||
|
|
||||||
|
return Version, Session, Entity, SentFile
|
||||||
|
|
||||||
def check_and_upgrade_database(self):
|
def check_and_upgrade_database(self):
|
||||||
row = DBVersion.query.get()
|
row = self.Version.query.all()
|
||||||
version = row.version if row else 1
|
version = row[0].version if row else 1
|
||||||
if version == LATEST_VERSION:
|
if version == LATEST_VERSION:
|
||||||
return
|
return
|
||||||
|
|
||||||
DBVersion.query.delete()
|
self.Version.query.delete()
|
||||||
|
|
||||||
# Implement table schema updates here and increase version
|
# Implement table schema updates here and increase version
|
||||||
|
|
||||||
self.db.add(DBVersion(version=version))
|
self.db.add(self.Version(version=version))
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
def new_session(self, session_id):
|
def new_session(self, session_id):
|
||||||
|
@ -97,7 +107,20 @@ class AlchemySession(MemorySession):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.container = container
|
self.container = container
|
||||||
self.db = container.db
|
self.db = container.db
|
||||||
|
self.Version, self.Session, self.Entity, self.SentFile = (
|
||||||
|
container.Version, container.Session, container.Entity,
|
||||||
|
container.SentFile)
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
|
self.load_session()
|
||||||
|
|
||||||
|
def load_session(self):
|
||||||
|
sessions = self._db_query(self.Session).all()
|
||||||
|
session = sessions[0] if sessions else None
|
||||||
|
if session:
|
||||||
|
self._dc_id = session.dc_id
|
||||||
|
self._server_address = session.server_address
|
||||||
|
self._port = session.port
|
||||||
|
self._auth_key = AuthKey(data=session.auth_key)
|
||||||
|
|
||||||
def clone(self, to_instance=None):
|
def clone(self, to_instance=None):
|
||||||
cloned = to_instance or self.__class__(self.container, self.session_id)
|
cloned = to_instance or self.__class__(self.container, self.session_id)
|
||||||
|
@ -107,17 +130,18 @@ class AlchemySession(MemorySession):
|
||||||
super().set_dc(dc_id, server_address, port)
|
super().set_dc(dc_id, server_address, port)
|
||||||
|
|
||||||
def _update_session_table(self):
|
def _update_session_table(self):
|
||||||
self.db.query(DBSession).filter(
|
self.Session.query.filter(
|
||||||
DBSession.session_id == self.session_id).delete()
|
self.Session.session_id == self.session_id).delete()
|
||||||
new = DBSession(session_id=self.session_id, dc_id=self._dc_id,
|
new = self.Session(session_id=self.session_id, dc_id=self._dc_id,
|
||||||
server_address=self._server_address, port=self._port,
|
server_address=self._server_address,
|
||||||
auth_key=self._auth_key.key if self._auth_key else b'')
|
port=self._port,
|
||||||
|
auth_key=(self._auth_key.key
|
||||||
|
if self._auth_key else b''))
|
||||||
self.db.merge(new)
|
self.db.merge(new)
|
||||||
|
|
||||||
def _db_query(self, dbclass, *args):
|
def _db_query(self, dbclass, *args):
|
||||||
return self.db.query(dbclass).filter(
|
return dbclass.query.filter(dbclass.session_id == self.session_id,
|
||||||
dbclass.session_id == self.session_id,
|
*args)
|
||||||
*args)
|
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
self.container.save()
|
self.container.save()
|
||||||
|
@ -127,42 +151,47 @@ class AlchemySession(MemorySession):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
self._db_query(DBSession).delete()
|
self._db_query(self.Session).delete()
|
||||||
self._db_query(DBEntity).delete()
|
self._db_query(self.Entity).delete()
|
||||||
self._db_query(DBSentFile).delete()
|
self._db_query(self.SentFile).delete()
|
||||||
|
|
||||||
def _entity_values_to_row(self, id, hash, username, phone, name):
|
def _entity_values_to_row(self, id, hash, username, phone, name):
|
||||||
return DBEntity(session_id=self.session_id, id=id, hash=hash,
|
return self.Entity(session_id=self.session_id, id=id, hash=hash,
|
||||||
username=username, phone=phone, name=name)
|
username=username, phone=phone, name=name)
|
||||||
|
|
||||||
def process_entities(self, tlo):
|
def process_entities(self, tlo):
|
||||||
rows = self._entities_to_rows(tlo)
|
rows = self._entities_to_rows(tlo)
|
||||||
if not rows:
|
if not rows:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.db.add_all(rows)
|
for row in rows:
|
||||||
|
self.db.merge(row)
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def get_entity_rows_by_phone(self, key):
|
def get_entity_rows_by_phone(self, key):
|
||||||
row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none()
|
row = self._db_query(self.Entity,
|
||||||
|
self.Entity.phone == key).one_or_none()
|
||||||
return row.id, row.hash if row else None
|
return row.id, row.hash if row else None
|
||||||
|
|
||||||
def get_entity_rows_by_username(self, key):
|
def get_entity_rows_by_username(self, key):
|
||||||
row = self._db_query(DBEntity, DBEntity.username == key).one_or_none()
|
row = self._db_query(self.Entity,
|
||||||
|
self.Entity.username == key).one_or_none()
|
||||||
return row.id, row.hash if row else None
|
return row.id, row.hash if row else None
|
||||||
|
|
||||||
def get_entity_rows_by_name(self, key):
|
def get_entity_rows_by_name(self, key):
|
||||||
row = self._db_query(DBEntity, DBEntity.name == key).one_or_none()
|
row = self._db_query(self.Entity,
|
||||||
|
self.Entity.name == key).one_or_none()
|
||||||
return row.id, row.hash if row else None
|
return row.id, row.hash if row else None
|
||||||
|
|
||||||
def get_entity_rows_by_id(self, key):
|
def get_entity_rows_by_id(self, key):
|
||||||
row = self._db_query(DBEntity, DBEntity.id == key).one_or_none()
|
row = self._db_query(self.Entity, self.Entity.id == key).one_or_none()
|
||||||
return row.id, row.hash if row else None
|
return row.id, row.hash if row else None
|
||||||
|
|
||||||
def get_file(self, md5_digest, file_size, cls):
|
def get_file(self, md5_digest, file_size, cls):
|
||||||
row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest,
|
row = self._db_query(self.SentFile,
|
||||||
DBSentFile.file_size == file_size,
|
self.SentFile.md5_digest == md5_digest,
|
||||||
DBSentFile.type == _SentFileType.from_type(
|
self.SentFile.file_size == file_size,
|
||||||
|
self.SentFile.type == _SentFileType.from_type(
|
||||||
cls).value).one_or_none()
|
cls).value).one_or_none()
|
||||||
return row.id, row.hash if row else None
|
return row.id, row.hash if row else None
|
||||||
|
|
||||||
|
@ -171,7 +200,7 @@ class AlchemySession(MemorySession):
|
||||||
raise TypeError('Cannot cache %s instance' % type(instance))
|
raise TypeError('Cannot cache %s instance' % type(instance))
|
||||||
|
|
||||||
self.db.merge(
|
self.db.merge(
|
||||||
DBSentFile(session_id=self.session_id, md5_digest=md5_digest,
|
self.SentFile(session_id=self.session_id, md5_digest=md5_digest,
|
||||||
type=_SentFileType.from_type(type(instance)).value,
|
type=_SentFileType.from_type(type(instance)).value,
|
||||||
id=instance.id, hash=instance.access_hash))
|
id=instance.id, hash=instance.access_hash))
|
||||||
self.save()
|
self.save()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user