Fix remaining upgraded uses of the session to work correctly

This commit is contained in:
Lonami Exo 2021-09-19 17:08:51 +02:00
parent d33402f02e
commit 9479e215fb
5 changed files with 28 additions and 23 deletions

View File

@ -319,6 +319,9 @@ async def connect(self: 'TelegramClient') -> None:
# TODO Get state from channels too # TODO Get state from channels too
self._state_cache = statecache.StateCache(state, self._log) self._state_cache = statecache.StateCache(state, self._log)
# Use known key, if any
self._sender.auth_key.key = dc.auth
if not await self._sender.connect(self._connection( if not await self._sender.connect(self._connection(
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)), str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
dc.port, dc.port,
@ -330,8 +333,8 @@ async def connect(self: 'TelegramClient') -> None:
# We don't want to init or modify anything if we were already connected # We don't want to init or modify anything if we were already connected
return return
if self._sender.auth_key.key != dc.key: if self._sender.auth_key.key != dc.auth:
dc.key = self._sender.auth_key.key dc.auth = self._sender.auth_key.key
await self.session.insert_dc(dc) await self.session.insert_dc(dc)
await self.session.save() await self.session.save()

View File

@ -271,19 +271,18 @@ async def get_input_entity(
# No InputPeer, cached peer, or known string. Fetch from session cache # No InputPeer, cached peer, or known string. Fetch from session cache
try: try:
peer = utils.get_peer(peer) peer_id = utils.get_peer_id(peer)
if isinstance(peer, _tl.PeerUser): except TypeError:
entity = await self.session.get_entity(Entity.USER, peer.user_id)
if entity:
return _tl.InputPeerUser(entity.id, entity.access_hash)
elif isinstance(peer, _tl.PeerChat):
return _tl.InputPeerChat(peer.chat_id)
elif isinstance(peer, _tl.PeerChannel):
entity = await self.session.get_entity(Entity.CHANNEL, peer.user_id)
if entity:
return _tl.InputPeerChannel(entity.id, entity.access_hash)
except ValueError:
pass pass
else:
entity = await self.session.get_entity(None, peer_id)
if entity:
if entity.ty in (Entity.USER, Entity.BOT):
return _tl.InputPeerUser(entity.id, entity.access_hash)
elif entity.ty in (Entity.GROUP):
return _tl.InputPeerChat(peer.chat_id)
elif entity.ty in (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP):
return _tl.InputPeerChannel(entity.id, entity.access_hash)
# Only network left to try # Only network left to try
if isinstance(peer, str): if isinstance(peer, str):

View File

@ -59,7 +59,7 @@ class Session(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def get_entity(self, ty: int, id: int) -> Optional[Entity]: async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
""" """
Get the `Entity` with matching ``ty`` and ``id``. Get the `Entity` with matching ``ty`` and ``id``.
@ -75,6 +75,8 @@ class Session(ABC):
the corresponding ``access_hash`` should still be returned. the corresponding ``access_hash`` should still be returned.
You may use `types.canonical_entity_type` to find out the canonical type. You may use `types.canonical_entity_type` to find out the canonical type.
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -36,7 +36,7 @@ class MemorySession(Session):
async def insert_entities(self, entities: List[Entity]): async def insert_entities(self, entities: List[Entity]):
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities) self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
async def get_entity(self, ty: int, id: int) -> Optional[Entity]: async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
try: try:
ty, access_hash = self.entities[id] ty, access_hash = self.entities[id]
return Entity(ty, id, access_hash) return Entity(ty, id, access_hash)

View File

@ -55,13 +55,17 @@ class SQLiteSession(Session):
self._upgrade_database(old=version) self._upgrade_database(old=version)
c.execute("delete from version") c.execute("delete from version")
c.execute("insert into version values (?)", (CURRENT_VERSION,)) c.execute("insert into version values (?)", (CURRENT_VERSION,))
self.save() self._conn.commit()
else: else:
# Tables don't exist, create new ones # Tables don't exist, create new ones
self._create_table(c, 'version (version integer primary key)')
self._mk_tables(c) self._mk_tables(c)
c.execute("insert into version values (?)", (CURRENT_VERSION,)) c.execute("insert into version values (?)", (CURRENT_VERSION,))
c.close() self._conn.commit()
self.save()
# Must have committed or else the version will not have been updated while new tables
# exist, leading to a half-upgraded state.
c.close()
def _upgrade_database(self, old): def _upgrade_database(self, old):
c = self._cursor() c = self._cursor()
@ -146,9 +150,6 @@ class SQLiteSession(Session):
def _mk_tables(self, c): def _mk_tables(self, c):
self._create_table( self._create_table(
c, c,
'''version (
version integer primary key
)''',
'''datacenter ( '''datacenter (
id integer primary key, id integer primary key,
ip text not null, ip text not null,
@ -243,7 +244,7 @@ class SQLiteSession(Session):
finally: finally:
c.close() c.close()
async def get_entity(self, ty: int, id: int) -> Optional[Entity]: async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
row = self._execute('select ty, id, access_hash from entity where id = ?', id) row = self._execute('select ty, id, access_hash from entity where id = ?', id)
return Entity(*row) if row else None return Entity(*row) if row else None