diff --git a/telethon/_client/telegrambaseclient.py b/telethon/_client/telegrambaseclient.py index 1050c2a5..1add1abb 100644 --- a/telethon/_client/telegrambaseclient.py +++ b/telethon/_client/telegrambaseclient.py @@ -319,6 +319,9 @@ async def connect(self: 'TelegramClient') -> None: # TODO Get state from channels too self._state_cache = statecache.StateCache(state, self._log) + # Use known key, if any + self._sender.auth_key.key = dc.auth + if not await self._sender.connect(self._connection( str(ipaddress.ip_address(dc.ipv6 or dc.ipv4)), dc.port, @@ -330,8 +333,8 @@ async def connect(self: 'TelegramClient') -> None: # We don't want to init or modify anything if we were already connected return - if self._sender.auth_key.key != dc.key: - dc.key = self._sender.auth_key.key + if self._sender.auth_key.key != dc.auth: + dc.auth = self._sender.auth_key.key await self.session.insert_dc(dc) await self.session.save() diff --git a/telethon/_client/users.py b/telethon/_client/users.py index 394baee9..18a6f0f4 100644 --- a/telethon/_client/users.py +++ b/telethon/_client/users.py @@ -271,19 +271,18 @@ async def get_input_entity( # No InputPeer, cached peer, or known string. Fetch from session cache try: - peer = utils.get_peer(peer) - if isinstance(peer, _tl.PeerUser): - entity = await self.session.get_entity(Entity.USER, peer.user_id) - if entity: - return _tl.InputPeerUser(entity.id, entity.access_hash) - elif isinstance(peer, _tl.PeerChat): - return _tl.InputPeerChat(peer.chat_id) - elif isinstance(peer, _tl.PeerChannel): - entity = await self.session.get_entity(Entity.CHANNEL, peer.user_id) - if entity: - return _tl.InputPeerChannel(entity.id, entity.access_hash) - except ValueError: + peer_id = utils.get_peer_id(peer) + except TypeError: pass + else: + entity = await self.session.get_entity(None, peer_id) + if entity: + if entity.ty in (Entity.USER, Entity.BOT): + return _tl.InputPeerUser(entity.id, entity.access_hash) + elif entity.ty in (Entity.GROUP): + return _tl.InputPeerChat(peer.chat_id) + elif entity.ty in (Entity.CHANNEL, Entity.MEGAGROUP, Entity.GIGAGROUP): + return _tl.InputPeerChannel(entity.id, entity.access_hash) # Only network left to try if isinstance(peer, str): diff --git a/telethon/sessions/abstract.py b/telethon/sessions/abstract.py index 4cdc9131..2b28ae76 100644 --- a/telethon/sessions/abstract.py +++ b/telethon/sessions/abstract.py @@ -59,7 +59,7 @@ class Session(ABC): raise NotImplementedError @abstractmethod - async def get_entity(self, ty: int, id: int) -> Optional[Entity]: + async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]: """ Get the `Entity` with matching ``ty`` and ``id``. @@ -75,6 +75,8 @@ class Session(ABC): the corresponding ``access_hash`` should still be returned. You may use `types.canonical_entity_type` to find out the canonical type. + + A ``ty`` with the value of ``None`` should be treated as "any entity with matching ID". """ raise NotImplementedError diff --git a/telethon/sessions/memory.py b/telethon/sessions/memory.py index 67602ec9..1c86aff7 100644 --- a/telethon/sessions/memory.py +++ b/telethon/sessions/memory.py @@ -36,7 +36,7 @@ class MemorySession(Session): async def insert_entities(self, entities: List[Entity]): self.entities.update((e.id, (e.ty, e.access_hash)) for e in entities) - async def get_entity(self, ty: int, id: int) -> Optional[Entity]: + async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]: try: ty, access_hash = self.entities[id] return Entity(ty, id, access_hash) diff --git a/telethon/sessions/sqlite.py b/telethon/sessions/sqlite.py index 5cd288aa..5bfc0433 100644 --- a/telethon/sessions/sqlite.py +++ b/telethon/sessions/sqlite.py @@ -55,13 +55,17 @@ class SQLiteSession(Session): self._upgrade_database(old=version) c.execute("delete from version") c.execute("insert into version values (?)", (CURRENT_VERSION,)) - self.save() + self._conn.commit() else: # Tables don't exist, create new ones + self._create_table(c, 'version (version integer primary key)') self._mk_tables(c) c.execute("insert into version values (?)", (CURRENT_VERSION,)) - c.close() - self.save() + self._conn.commit() + + # Must have committed or else the version will not have been updated while new tables + # exist, leading to a half-upgraded state. + c.close() def _upgrade_database(self, old): c = self._cursor() @@ -146,9 +150,6 @@ class SQLiteSession(Session): def _mk_tables(self, c): self._create_table( c, - '''version ( - version integer primary key - )''', '''datacenter ( id integer primary key, ip text not null, @@ -243,7 +244,7 @@ class SQLiteSession(Session): finally: c.close() - async def get_entity(self, ty: int, id: int) -> Optional[Entity]: + async def get_entity(self, ty: Optional[int], id: int) -> Optional[Entity]: row = self._execute('select ty, id, access_hash from entity where id = ?', id) return Entity(*row) if row else None