mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-25 19:03:46 +03:00
Fix remaining upgraded uses of the session to work correctly
This commit is contained in:
parent
d33402f02e
commit
9479e215fb
|
@ -319,6 +319,9 @@ async def connect(self: 'TelegramClient') -> None:
|
|||
# TODO Get state from channels too
|
||||
self._state_cache = statecache.StateCache(state, self._log)
|
||||
|
||||
# Use known key, if any
|
||||
self._sender.auth_key.key = dc.auth
|
||||
|
||||
if not await self._sender.connect(self._connection(
|
||||
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
|
||||
dc.port,
|
||||
|
@ -330,8 +333,8 @@ async def connect(self: 'TelegramClient') -> None:
|
|||
# We don't want to init or modify anything if we were already connected
|
||||
return
|
||||
|
||||
if self._sender.auth_key.key != dc.key:
|
||||
dc.key = self._sender.auth_key.key
|
||||
if self._sender.auth_key.key != dc.auth:
|
||||
dc.auth = self._sender.auth_key.key
|
||||
await self.session.insert_dc(dc)
|
||||
await self.session.save()
|
||||
|
||||
|
|
|
@ -271,19 +271,18 @@ async def get_input_entity(
|
|||
|
||||
# No InputPeer, cached peer, or known string. Fetch from session cache
|
||||
try:
|
||||
peer = utils.get_peer(peer)
|
||||
if isinstance(peer, _tl.PeerUser):
|
||||
entity = await self.session.get_entity(Entity.USER, peer.user_id)
|
||||
if entity:
|
||||
return _tl.InputPeerUser(entity.id, entity.access_hash)
|
||||
elif isinstance(peer, _tl.PeerChat):
|
||||
return _tl.InputPeerChat(peer.chat_id)
|
||||
elif isinstance(peer, _tl.PeerChannel):
|
||||
entity = await self.session.get_entity(Entity.CHANNEL, peer.user_id)
|
||||
if entity:
|
||||
return _tl.InputPeerChannel(entity.id, entity.access_hash)
|
||||
except ValueError:
|
||||
peer_id = utils.get_peer_id(peer)
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
entity = await self.session.get_entity(None, peer_id)
|
||||
if entity:
|
||||
if entity.ty in (Entity.USER, Entity.BOT):
|
||||
return _tl.InputPeerUser(entity.id, entity.access_hash)
|
||||
elif entity.ty in (Entity.GROUP):
|
||||
return _tl.InputPeerChat(peer.chat_id)
|
||||
elif entity.ty in (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP):
|
||||
return _tl.InputPeerChannel(entity.id, entity.access_hash)
|
||||
|
||||
# Only network left to try
|
||||
if isinstance(peer, str):
|
||||
|
|
|
@ -59,7 +59,7 @@ class Session(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
||||
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||
"""
|
||||
Get the `Entity` with matching ``ty`` and ``id``.
|
||||
|
||||
|
@ -75,6 +75,8 @@ class Session(ABC):
|
|||
the corresponding ``access_hash`` should still be returned.
|
||||
|
||||
You may use `types.canonical_entity_type` to find out the canonical type.
|
||||
|
||||
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class MemorySession(Session):
|
|||
async def insert_entities(self, entities: List[Entity]):
|
||||
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
|
||||
|
||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
||||
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||
try:
|
||||
ty, access_hash = self.entities[id]
|
||||
return Entity(ty, id, access_hash)
|
||||
|
|
|
@ -55,13 +55,17 @@ class SQLiteSession(Session):
|
|||
self._upgrade_database(old=version)
|
||||
c.execute("delete from version")
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
self.save()
|
||||
self._conn.commit()
|
||||
else:
|
||||
# Tables don't exist, create new ones
|
||||
self._create_table(c, 'version (version integer primary key)')
|
||||
self._mk_tables(c)
|
||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||
c.close()
|
||||
self.save()
|
||||
self._conn.commit()
|
||||
|
||||
# Must have committed or else the version will not have been updated while new tables
|
||||
# exist, leading to a half-upgraded state.
|
||||
c.close()
|
||||
|
||||
def _upgrade_database(self, old):
|
||||
c = self._cursor()
|
||||
|
@ -146,9 +150,6 @@ class SQLiteSession(Session):
|
|||
def _mk_tables(self, c):
|
||||
self._create_table(
|
||||
c,
|
||||
'''version (
|
||||
version integer primary key
|
||||
)''',
|
||||
'''datacenter (
|
||||
id integer primary key,
|
||||
ip text not null,
|
||||
|
@ -243,7 +244,7 @@ class SQLiteSession(Session):
|
|||
finally:
|
||||
c.close()
|
||||
|
||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
||||
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
|
||||
return Entity(*row) if row else None
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user