mirror of
https://github.com/LonamiWebs/Telethon.git
synced 2024-11-29 12:53:44 +03:00
Fix remaining upgraded uses of the session to work correctly
This commit is contained in:
parent
d33402f02e
commit
9479e215fb
|
@ -319,6 +319,9 @@ async def connect(self: 'TelegramClient') -> None:
|
||||||
# TODO Get state from channels too
|
# TODO Get state from channels too
|
||||||
self._state_cache = statecache.StateCache(state, self._log)
|
self._state_cache = statecache.StateCache(state, self._log)
|
||||||
|
|
||||||
|
# Use known key, if any
|
||||||
|
self._sender.auth_key.key = dc.auth
|
||||||
|
|
||||||
if not await self._sender.connect(self._connection(
|
if not await self._sender.connect(self._connection(
|
||||||
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
|
str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)),
|
||||||
dc.port,
|
dc.port,
|
||||||
|
@ -330,8 +333,8 @@ async def connect(self: 'TelegramClient') -> None:
|
||||||
# We don't want to init or modify anything if we were already connected
|
# We don't want to init or modify anything if we were already connected
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._sender.auth_key.key != dc.key:
|
if self._sender.auth_key.key != dc.auth:
|
||||||
dc.key = self._sender.auth_key.key
|
dc.auth = self._sender.auth_key.key
|
||||||
await self.session.insert_dc(dc)
|
await self.session.insert_dc(dc)
|
||||||
await self.session.save()
|
await self.session.save()
|
||||||
|
|
||||||
|
|
|
@ -271,19 +271,18 @@ async def get_input_entity(
|
||||||
|
|
||||||
# No InputPeer, cached peer, or known string. Fetch from session cache
|
# No InputPeer, cached peer, or known string. Fetch from session cache
|
||||||
try:
|
try:
|
||||||
peer = utils.get_peer(peer)
|
peer_id = utils.get_peer_id(peer)
|
||||||
if isinstance(peer, _tl.PeerUser):
|
except TypeError:
|
||||||
entity = await self.session.get_entity(Entity.USER, peer.user_id)
|
|
||||||
if entity:
|
|
||||||
return _tl.InputPeerUser(entity.id, entity.access_hash)
|
|
||||||
elif isinstance(peer, _tl.PeerChat):
|
|
||||||
return _tl.InputPeerChat(peer.chat_id)
|
|
||||||
elif isinstance(peer, _tl.PeerChannel):
|
|
||||||
entity = await self.session.get_entity(Entity.CHANNEL, peer.user_id)
|
|
||||||
if entity:
|
|
||||||
return _tl.InputPeerChannel(entity.id, entity.access_hash)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
pass
|
||||||
|
else:
|
||||||
|
entity = await self.session.get_entity(None, peer_id)
|
||||||
|
if entity:
|
||||||
|
if entity.ty in (Entity.USER, Entity.BOT):
|
||||||
|
return _tl.InputPeerUser(entity.id, entity.access_hash)
|
||||||
|
elif entity.ty in (Entity.GROUP):
|
||||||
|
return _tl.InputPeerChat(peer.chat_id)
|
||||||
|
elif entity.ty in (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP):
|
||||||
|
return _tl.InputPeerChannel(entity.id, entity.access_hash)
|
||||||
|
|
||||||
# Only network left to try
|
# Only network left to try
|
||||||
if isinstance(peer, str):
|
if isinstance(peer, str):
|
||||||
|
|
|
@ -59,7 +59,7 @@ class Session(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||||
"""
|
"""
|
||||||
Get the `Entity` with matching ``ty`` and ``id``.
|
Get the `Entity` with matching ``ty`` and ``id``.
|
||||||
|
|
||||||
|
@ -75,6 +75,8 @@ class Session(ABC):
|
||||||
the corresponding ``access_hash`` should still be returned.
|
the corresponding ``access_hash`` should still be returned.
|
||||||
|
|
||||||
You may use `types.canonical_entity_type` to find out the canonical type.
|
You may use `types.canonical_entity_type` to find out the canonical type.
|
||||||
|
|
||||||
|
A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID".
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ class MemorySession(Session):
|
||||||
async def insert_entities(self, entities: List[Entity]):
|
async def insert_entities(self, entities: List[Entity]):
|
||||||
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
|
self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities)
|
||||||
|
|
||||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||||
try:
|
try:
|
||||||
ty, access_hash = self.entities[id]
|
ty, access_hash = self.entities[id]
|
||||||
return Entity(ty, id, access_hash)
|
return Entity(ty, id, access_hash)
|
||||||
|
|
|
@ -55,13 +55,17 @@ class SQLiteSession(Session):
|
||||||
self._upgrade_database(old=version)
|
self._upgrade_database(old=version)
|
||||||
c.execute("delete from version")
|
c.execute("delete from version")
|
||||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||||
self.save()
|
self._conn.commit()
|
||||||
else:
|
else:
|
||||||
# Tables don't exist, create new ones
|
# Tables don't exist, create new ones
|
||||||
|
self._create_table(c, 'version (version integer primary key)')
|
||||||
self._mk_tables(c)
|
self._mk_tables(c)
|
||||||
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
c.execute("insert into version values (?)", (CURRENT_VERSION,))
|
||||||
c.close()
|
self._conn.commit()
|
||||||
self.save()
|
|
||||||
|
# Must have committed or else the version will not have been updated while new tables
|
||||||
|
# exist, leading to a half-upgraded state.
|
||||||
|
c.close()
|
||||||
|
|
||||||
def _upgrade_database(self, old):
|
def _upgrade_database(self, old):
|
||||||
c = self._cursor()
|
c = self._cursor()
|
||||||
|
@ -146,9 +150,6 @@ class SQLiteSession(Session):
|
||||||
def _mk_tables(self, c):
|
def _mk_tables(self, c):
|
||||||
self._create_table(
|
self._create_table(
|
||||||
c,
|
c,
|
||||||
'''version (
|
|
||||||
version integer primary key
|
|
||||||
)''',
|
|
||||||
'''datacenter (
|
'''datacenter (
|
||||||
id integer primary key,
|
id integer primary key,
|
||||||
ip text not null,
|
ip text not null,
|
||||||
|
@ -243,7 +244,7 @@ class SQLiteSession(Session):
|
||||||
finally:
|
finally:
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
async def get_entity(self, ty: int, id: int) -> Optional[Entity]:
|
async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]:
|
||||||
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
|
row = self._execute('select ty, id, access_hash from entity where id = ?', id)
|
||||||
return Entity(*row) if row else None
|
return Entity(*row) if row else None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user