From c1a8896faa43dfa798fc9223283eae42dbf5afd6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Mar 2018 20:14:11 +0200 Subject: [PATCH] Fix SQLAlchemy implementation --- telethon/sessions/__init__.py | 1 + telethon/sessions/sqlalchemy.py | 197 ++++++++++++++++++-------------- 2 files changed, 114 insertions(+), 84 deletions(-) diff --git a/telethon/sessions/__init__.py b/telethon/sessions/__init__.py index af3423f3..a487a4bd 100644 --- a/telethon/sessions/__init__.py +++ b/telethon/sessions/__init__.py @@ -1,3 +1,4 @@ from .abstract import Session from .memory import MemorySession from .sqlite import SQLiteSession +from .sqlalchemy import AlchemySessionContainer, AlchemySession diff --git a/telethon/sessions/sqlalchemy.py b/telethon/sessions/sqlalchemy.py index 782f811f..aa618e4c 100644 --- a/telethon/sessions/sqlalchemy.py +++ b/telethon/sessions/sqlalchemy.py @@ -1,85 +1,95 @@ from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, String, Integer, Blob, orm +from sqlalchemy import Column, String, Integer, BLOB, orm import sqlalchemy as sql +from ..crypto import AuthKey from ..tl.types import InputPhoto, InputDocument from .memory import MemorySession, _SentFileType -Base = declarative_base() LATEST_VERSION = 1 -class DBVersion(Base): - __tablename__ = "version" - version = Column(Integer, primary_key=True) - - -class DBSession(Base): - __tablename__ = "sessions" - - session_id = Column(String, primary_key=True) - dc_id = Column(Integer, primary_key=True) - server_address = Column(String) - port = Column(Integer) - auth_key = Column(Blob) - - -class DBEntity(Base): - __tablename__ = "entities" - - session_id = Column(String, primary_key=True) - id = Column(Integer, primary_key=True) - hash = Column(Integer, nullable=False) - username = Column(String) - phone = Column(Integer) - name = Column(String) - - -class DBSentFile(Base): - __tablename__ = "sent_files" - - session_id = Column(String, primary_key=True) - md5_digest = Column(Blob, primary_key=True) - file_size = Column(Integer, primary_key=True) - type = Column(Integer, primary_key=True) - id = Column(Integer) - hash = Column(Integer) - - class AlchemySessionContainer: - def __init__(self, database): - if isinstance(database, str): - database = sql.create_engine(database) + def __init__(self, engine=None, session=None, table_prefix="", + table_base=None, manage_tables=True): + if isinstance(engine, str): + engine = sql.create_engine(engine) - self.db_engine = database - db_factory = orm.sessionmaker(bind=self.db_engine) - self.db = orm.scoping.scoped_session(db_factory) - - if not self.db_engine.dialect.has_table(self.db_engine, - DBVersion.__tablename__): - Base.metadata.create_all(bind=self.db_engine) - self.db.add(DBVersion(version=LATEST_VERSION)) - self.db.commit() + self.db_engine = engine + if not session: + db_factory = orm.sessionmaker(bind=self.db_engine) + self.db = orm.scoping.scoped_session(db_factory) else: - self.check_and_upgrade_database() + self.db = session - DBVersion.query = self.db.query_property() - DBSession.query = self.db.query_property() - DBEntity.query = self.db.query_property() - DBSentFile.query = self.db.query_property() + table_base = table_base or declarative_base() + (self.Version, self.Session, self.Entity, + self.SentFile) = self.create_table_classes(self.db, table_prefix, + table_base) + + if manage_tables: + table_base.metadata.bind = self.db_engine + if not self.db_engine.dialect.has_table(self.db_engine, + self.Version.__tablename__): + table_base.metadata.create_all() + self.db.add(self.Version(version=LATEST_VERSION)) + self.db.commit() + else: + self.check_and_upgrade_database() + + @staticmethod + def create_table_classes(db, prefix, Base): + class Version(Base): + query = db.query_property() + __tablename__ = "{prefix}version".format(prefix=prefix) + version = Column(Integer, primary_key=True) + + class Session(Base): + query = db.query_property() + __tablename__ = "{prefix}sessions".format(prefix=prefix) + + session_id = Column(String, primary_key=True) + dc_id = Column(Integer, primary_key=True) + server_address = Column(String) + port = Column(Integer) + auth_key = Column(BLOB) + + class Entity(Base): + query = db.query_property() + __tablename__ = "{prefix}entities".format(prefix=prefix) + + session_id = Column(String, primary_key=True) + id = Column(Integer, primary_key=True) + hash = Column(Integer, nullable=False) + username = Column(String) + phone = Column(Integer) + name = Column(String) + + class SentFile(Base): + query = db.query_property() + __tablename__ = "{prefix}sent_files".format(prefix=prefix) + + session_id = Column(String, primary_key=True) + md5_digest = Column(BLOB, primary_key=True) + file_size = Column(Integer, primary_key=True) + type = Column(Integer, primary_key=True) + id = Column(Integer) + hash = Column(Integer) + + return Version, Session, Entity, SentFile def check_and_upgrade_database(self): - row = DBVersion.query.get() - version = row.version if row else 1 + row = self.Version.query.all() + version = row[0].version if row else 1 if version == LATEST_VERSION: return - DBVersion.query.delete() + self.Version.query.delete() # Implement table schema updates here and increase version - self.db.add(DBVersion(version=version)) + self.db.add(self.Version(version=version)) self.db.commit() def new_session(self, session_id): @@ -97,7 +107,20 @@ class AlchemySession(MemorySession): super().__init__() self.container = container self.db = container.db + self.Version, self.Session, self.Entity, self.SentFile = ( + container.Version, container.Session, container.Entity, + container.SentFile) self.session_id = session_id + self.load_session() + + def load_session(self): + sessions = self._db_query(self.Session).all() + session = sessions[0] if sessions else None + if session: + self._dc_id = session.dc_id + self._server_address = session.server_address + self._port = session.port + self._auth_key = AuthKey(data=session.auth_key) def clone(self, to_instance=None): cloned = to_instance or self.__class__(self.container, self.session_id) @@ -107,17 +130,18 @@ class AlchemySession(MemorySession): super().set_dc(dc_id, server_address, port) def _update_session_table(self): - self.db.query(DBSession).filter( - DBSession.session_id == self.session_id).delete() - new = DBSession(session_id=self.session_id, dc_id=self._dc_id, - server_address=self._server_address, port=self._port, - auth_key=self._auth_key.key if self._auth_key else b'') + self.Session.query.filter( + self.Session.session_id == self.session_id).delete() + new = self.Session(session_id=self.session_id, dc_id=self._dc_id, + server_address=self._server_address, + port=self._port, + auth_key=(self._auth_key.key + if self._auth_key else b'')) self.db.merge(new) def _db_query(self, dbclass, *args): - return self.db.query(dbclass).filter( - dbclass.session_id == self.session_id, - *args) + return dbclass.query.filter(dbclass.session_id == self.session_id, + *args) def save(self): self.container.save() @@ -127,42 +151,47 @@ class AlchemySession(MemorySession): pass def delete(self): - self._db_query(DBSession).delete() - self._db_query(DBEntity).delete() - self._db_query(DBSentFile).delete() + self._db_query(self.Session).delete() + self._db_query(self.Entity).delete() + self._db_query(self.SentFile).delete() def _entity_values_to_row(self, id, hash, username, phone, name): - return DBEntity(session_id=self.session_id, id=id, hash=hash, - username=username, phone=phone, name=name) + return self.Entity(session_id=self.session_id, id=id, hash=hash, + username=username, phone=phone, name=name) def process_entities(self, tlo): rows = self._entities_to_rows(tlo) if not rows: return - self.db.add_all(rows) + for row in rows: + self.db.merge(row) self.save() def get_entity_rows_by_phone(self, key): - row = self._db_query(DBEntity, DBEntity.phone == key).one_or_none() + row = self._db_query(self.Entity, + self.Entity.phone == key).one_or_none() return row.id, row.hash if row else None def get_entity_rows_by_username(self, key): - row = self._db_query(DBEntity, DBEntity.username == key).one_or_none() + row = self._db_query(self.Entity, + self.Entity.username == key).one_or_none() return row.id, row.hash if row else None def get_entity_rows_by_name(self, key): - row = self._db_query(DBEntity, DBEntity.name == key).one_or_none() + row = self._db_query(self.Entity, + self.Entity.name == key).one_or_none() return row.id, row.hash if row else None def get_entity_rows_by_id(self, key): - row = self._db_query(DBEntity, DBEntity.id == key).one_or_none() + row = self._db_query(self.Entity, self.Entity.id == key).one_or_none() return row.id, row.hash if row else None def get_file(self, md5_digest, file_size, cls): - row = self._db_query(DBSentFile, DBSentFile.md5_digest == md5_digest, - DBSentFile.file_size == file_size, - DBSentFile.type == _SentFileType.from_type( + row = self._db_query(self.SentFile, + self.SentFile.md5_digest == md5_digest, + self.SentFile.file_size == file_size, + self.SentFile.type == _SentFileType.from_type( cls).value).one_or_none() return row.id, row.hash if row else None @@ -171,7 +200,7 @@ class AlchemySession(MemorySession): raise TypeError('Cannot cache %s instance' % type(instance)) self.db.merge( - DBSentFile(session_id=self.session_id, md5_digest=md5_digest, - type=_SentFileType.from_type(type(instance)).value, - id=instance.id, hash=instance.access_hash)) + self.SentFile(session_id=self.session_id, md5_digest=md5_digest, + type=_SentFileType.from_type(type(instance)).value, + id=instance.id, hash=instance.access_hash)) self.save()