diff --git a/README.rst b/README.rst index a2e0d3de..3dd6b9aa 100755 --- a/README.rst +++ b/README.rst @@ -41,6 +41,26 @@ Creating a client client.start() +Store sessions in Redis +----------------------- + +.. code:: python + + from telethon import TelegramClient + from telethon.sessions.redis import RedisSession + import redis + + # These example values won't work. You must get your own api_id and + # api_hash from https://my.telegram.org, under API Development. + api_id = 12345 + api_hash = '0123456789abcdef0123456789abcdef' + + redis_connector = redis.Redis(host='localhost', port=6379, db=0, decode_responses=False) + session = RedisSession('session_name', redis_connector) + client = TelegramClient(session, api_id, api_hash) + client.start() + + Doing stuff ----------- diff --git a/optional-requirements.txt b/optional-requirements.txt index fb83c1ab..15871a78 100644 --- a/optional-requirements.txt +++ b/optional-requirements.txt @@ -2,3 +2,4 @@ cryptg pysocks hachoir3 sqlalchemy +redis diff --git a/setup.py b/setup.py index 05ca9197..4fd005a6 100755 --- a/setup.py +++ b/setup.py @@ -79,7 +79,7 @@ def main(): # Try importing the telethon module to assert it has no errors try: import telethon - except: + except Exception: print('Packaging for PyPi aborted, importing the module failed.') return @@ -113,16 +113,18 @@ def main(): version = re.search(r"^__version__\s*=\s*'(.*)'.*$", f.read(), flags=re.MULTILINE).group(1) setup( - name='Telethon', + name='Telethon-ezdev128', version=version, - description="Full-featured Telegram client library for Python 3", + description="Full-featured Telegram client library for Python 3 (ezdev128's fork-and-merge)", long_description=long_description, - url='https://github.com/LonamiWebs/Telethon', - download_url='https://github.com/LonamiWebs/Telethon/releases', + url='https://github.com/ezdev128/Telethon', + download_url='https://github.com/ezdev128/Telethon/releases', - author='Lonami Exo', - author_email='totufals@hotmail.com', + author='Konstantin M.', + author_email='ezdev128@yandex.com', + maintainer='Konstantin M.', + maintainer_email='ezdev128@yandex.com', license='MIT', @@ -157,7 +159,8 @@ def main(): 'typing' if version_info < (3, 5) else ""], extras_require={ 'cryptg': ['cryptg'], - 'sqlalchemy': ['sqlalchemy'] + 'sqlalchemy': ['sqlalchemy'], + 'redis': ['redis'], } ) diff --git a/telethon/events/__init__.py b/telethon/events/__init__.py index 36020afb..956c98da 100644 --- a/telethon/events/__init__.py +++ b/telethon/events/__init__.py @@ -86,6 +86,7 @@ class _EventCommon(abc.ABC): and not broadcast ) self.is_channel = isinstance(chat_peer, types.PeerChannel) + self.is_supergroup = self.is_group and self.is_channel def _get_entity(self, msg_id, entity_id, chat=None): """ @@ -295,6 +296,9 @@ class NewMessage(_EventBuilder): is_group (:obj:`bool`): True if the message was sent on a group or megagroup. + is_supergroup (:obj:`bool`): + True if the message was sent on a supergroup. + is_channel (:obj:`bool`): True if the message was sent on a megagroup or channel. diff --git a/telethon/sessions/__init__.py b/telethon/sessions/__init__.py index a487a4bd..24e32484 100644 --- a/telethon/sessions/__init__.py +++ b/telethon/sessions/__init__.py @@ -2,3 +2,4 @@ from .abstract import Session from .memory import MemorySession from .sqlite import SQLiteSession from .sqlalchemy import AlchemySessionContainer, AlchemySession +from .redis import RedisSession diff --git a/telethon/sessions/redis.py b/telethon/sessions/redis.py new file mode 100644 index 00000000..bdd28776 --- /dev/null +++ b/telethon/sessions/redis.py @@ -0,0 +1,292 @@ + +from .memory import MemorySession, _SentFileType +from ..crypto import AuthKey +from .. import utils +from ..tl.types import ( + InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel +) +import logging +import json +import base64 +import time +import redis +import pickle + +TS_STR_FORMAT = "%F %T" +HIVE_PREFIX = "telethon:client" +PACK_FUNC = "json" +UNPACK_FUNC = "json" + + +__log__ = logging.getLogger(__name__) + + +class RedisSession(MemorySession): + + session_name = None + redis_connection = None + hive_prefix = None + sess_prefix = None + pack_func = None + unpack_func = None + use_indents = True + + def __init__(self, session_name=None, redis_connection=None, hive_prefix=None, + pack_func=PACK_FUNC, unpack_func=UNPACK_FUNC): + if not isinstance(session_name, (str, bytes)): + raise TypeError("Session name must be a string or bytes.") + + if not redis_connection or not isinstance(redis_connection, redis.StrictRedis): + raise TypeError('The given redis_connection must be a Redis instance.') + + super().__init__() + + self.session_name = session_name if isinstance(session_name, str) else session_name.decode() + self.redis_connection = redis_connection + + self.hive_prefix = hive_prefix or HIVE_PREFIX + self.pack_func = pack_func + self.unpack_func = unpack_func + + self.sess_prefix = "{}:{}".format(self.hive_prefix, self.session_name) + + self.save_entities = True + + self.feed_session() + + def _pack(self, o, **kwargs): + if self.pack_func == "json": + if self.use_indents: + kwargs["indent"] = 2 + return json.dumps(o, **kwargs) if self.pack_func == "json" else pickle.dumps(o, **kwargs) + + def _unpack(self, o, **kwargs): + return json.loads(o, **kwargs) if self.unpack_func == "json" else pickle.loads(o, **kwargs) + + def feed_session(self): + try: + s = self._get_sessions() + if len(s) == 0: + self._auth_key = AuthKey(data=bytes()) + return + + s = self.redis_connection.get(s[-1]) + if not s: + # No sessions + self._auth_key = AuthKey(data=bytes()) + return + + s = self._unpack(s) + self._dc_id = s["dc_id"] + self._server_address = s["server_address"] + self._port = s["port"] + auth_key = base64.standard_b64decode(s["auth_key"]) + self._auth_key = AuthKey(data=auth_key) + except Exception as ex: + __log__.exception(ex.args) + + def _update_sessions(self): + """ + Stores session into redis. + """ + auth_key = self._auth_key.key if self._auth_key else bytes() + if not self._dc_id: + return + + s = { + "dc_id": self._dc_id, + "server_address": self._server_address, + "port": self._port, + "auth_key": base64.standard_b64encode(auth_key).decode(), + "ts_ts": time.time(), + "ts_str": time.strftime(TS_STR_FORMAT, time.localtime()), + } + + key = "{}:sessions:{}".format(self.sess_prefix, self._dc_id) + try: + self.redis_connection.set(key, self._pack(s)) + except Exception as ex: + __log__.exception(ex.args) + + def set_dc(self, dc_id, server_address, port): + """ + Sets the information of the data center address and port that + the library should connect to, as well as the data center ID, + which is currently unused. + """ + super().set_dc(dc_id, server_address, port) + self._update_sessions() + + auth_key = bytes() + + if not self._dc_id: + self._auth_key = AuthKey(data=auth_key) + return + + key = "{}:sessions:{}".format(self.sess_prefix, self._dc_id) + s = self.redis_connection.get(key) + if s: + s = self._unpack(s) + auth_key = base64.standard_b64decode(s["auth_key"]) + self._auth_key = AuthKey(data=auth_key) + + @MemorySession.auth_key.setter + def auth_key(self, value): + """ + Sets the ``AuthKey`` to be used for the saved data center. + """ + self._auth_key = value + self._update_sessions() + + def list_sessions(self): + """ + Lists available sessions. Not used by the library itself. + """ + return self._get_sessions(strip_prefix=True) + + def process_entities(self, tlo): + """ + Processes the input ``TLObject`` or ``list`` and saves + whatever information is relevant (e.g., ID or access hash). + """ + + if not self.save_entities: + return + + rows = self._entities_to_rows(tlo) + if not rows or len(rows) == 0 or len(rows[0]) == 0: + return + + try: + rows = rows[0] + key = "{}:entities:{}".format(self.sess_prefix, rows[0]) + s = { + "id": rows[0], + "hash": rows[1], + "username": rows[2], + "phone": rows[3], + "name": rows[4], + "ts_ts": time.time(), + "ts_str": time.strftime(TS_STR_FORMAT, time.localtime()), + } + self.redis_connection.set(key, self._pack(s)) + except Exception as ex: + __log__.exception(ex.args) + + def _get_entities(self, strip_prefix=False): + """ + Returns list of entities. if strip_prefix is False - returns redis keys, + else returns list of id's + """ + key_pattern = "{}:{}:entities:".format(self.hive_prefix, self.session_name) + try: + entities = self.redis_connection.keys(key_pattern+"*") + if not strip_prefix: + return entities + return [s.decode().replace(key_pattern, "") for s in entities] + except Exception as ex: + __log__.exception(ex.args) + return [] + + def _get_sessions(self, strip_prefix=False): + """ + Returns list of sessions. if strip_prefix is False - returns redis keys, + else returns list of id's + """ + key_pattern = "{}:{}:sessions:".format(self.hive_prefix, self.session_name) + try: + sessions = self.redis_connection.keys(key_pattern+"*") + return [s.decode().replace(key_pattern, "") if strip_prefix else s.decode() for s in sessions] + except Exception as ex: + __log__.exception(ex.args) + return [] + + def get_entity_rows_by_phone(self, phone): + try: + for key in self._get_entities(): + entity = self._unpack(self.redis_connection.get(key)) + if "phone" in entity and entity["phone"] == phone: + return entity["id"], entity["hash"] + except Exception as ex: + __log__.exception(ex.args) + return None + + def get_entity_rows_by_username(self, username): + try: + for key in self._get_entities(): + entity = self._unpack(self.redis_connection.get(key)) + if "username" in entity and entity["username"] == username: + return entity["id"], entity["hash"] + except Exception as ex: + __log__.exception(ex.args) + return None + + def get_entity_rows_by_name(self, name): + try: + for key in self._get_entities(): + entity = self._unpack(self.redis_connection.get(key)) + if "name" in entity and entity["name"] == name: + return entity["id"], entity["hash"] + except Exception as ex: + __log__.exception(ex.args) + + return None + + def get_entity_rows_by_id(self, entity_id, exact=True): + if exact: + key = "{}:entities:{}".format(self.sess_prefix, entity_id) + s = self.redis_connection.get(key) + if not s: + return None + try: + s = self._unpack(s) + return entity_id, s["hash"] + except Exception as ex: + __log__.exception(ex.args) + return None + else: + ids = ( + utils.get_peer_id(PeerUser(entity_id)), + utils.get_peer_id(PeerChat(entity_id)), + utils.get_peer_id(PeerChannel(entity_id)) + ) + + try: + for key in self._get_entities(): + entity = self._unpack(self.redis_connection.get(key)) + if "id" in entity and entity["id"] in ids: + return entity["id"], entity["hash"] + except Exception as ex: + __log__.exception(ex.args) + + def get_file(self, md5_digest, file_size, cls): + key = "{}:sent_files:{}".format(self.sess_prefix, md5_digest) + s = self.redis_connection.get(key) + if s: + try: + s = self._unpack(s) + return md5_digest, file_size \ + if s["file_size"] == file_size and s["type"] == _SentFileType.from_type(cls).value \ + else None + except Exception as ex: + __log__.exception(ex.args) + return None + + def cache_file(self, md5_digest, file_size, instance): + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError('Cannot cache {} instance'.format(type(instance))) + + key = "{}:sent_files:{}".format(self.sess_prefix, md5_digest) + s = { + "md5_digest": md5_digest, + "file_size": file_size, + "type": _SentFileType.from_type(type(instance)).value, + "id": instance.id, + "hash": instance.access_hash, + "ts_ts": time.time(), + "ts_str": time.strftime(TS_STR_FORMAT, time.localtime()), + } + try: + self.redis_connection.set(key, self._pack(s)) + except Exception as ex: + __log__.exception(ex.args) diff --git a/telethon/telegram_bare_client.py b/telethon/telegram_bare_client.py index 7164bb17..42c14d98 100644 --- a/telethon/telegram_bare_client.py +++ b/telethon/telegram_bare_client.py @@ -73,7 +73,6 @@ class TelegramBareClient: update_workers=None, spawn_read_thread=False, timeout=timedelta(seconds=5), - loop=None, device_model=None, system_version=None, app_version=None, @@ -611,7 +610,7 @@ class TelegramBareClient: ) self._recv_thread.start() - def _signal_handler(self, signum, frame): + def _signal_handler(self, *, _): if self._user_connected: self.disconnect() else: @@ -673,7 +672,7 @@ class TelegramBareClient: # a ping) if we want to receive updates again. # TODO Test if getDifference works too (better alternative) self._sender.send(GetStateRequest()) - except: + except Exception: self._idling.clear() raise diff --git a/telethon/telegram_client.py b/telethon/telegram_client.py index 02bf918e..294df506 100644 --- a/telethon/telegram_client.py +++ b/telethon/telegram_client.py @@ -192,6 +192,9 @@ class TelegramClient(TelegramBareClient): # Sometimes we need to know who we are, cache the self peer self._self_input_peer = None + # Store get_me() after successful sign in + self.me = None + # endregion # region Telegram requests functions @@ -298,10 +301,12 @@ class TelegramClient(TelegramBareClient): if self.is_user_authorized(): self._check_events_pending_resolve() + self.me = self.get_me() return self if bot_token: self.sign_in(bot_token=bot_token) + self.me = self.get_me() return self # Turn the callable into a valid phone number @@ -355,6 +360,7 @@ class TelegramClient(TelegramBareClient): # We won't reach here if any step failed (exit by exception) print('Signed in successfully as', utils.get_display_name(me)) self._check_events_pending_resolve() + self.me = self.get_me() return self def sign_in(self, phone=None, code=None, @@ -1152,9 +1158,9 @@ class TelegramClient(TelegramBareClient): raise TypeError('Invalid message type: {}'.format(type(message))) def iter_participants(self, entity, limit=None, search='', - filter=None, aggressive=False, _total_box=None): + aggressive=False, _total_box=None): """ - Iterator over the participants belonging to the specified chat. + Gets the list of participants from the specified entity. Args: entity (:obj:`entity`): @@ -1166,12 +1172,6 @@ class TelegramClient(TelegramBareClient): search (:obj:`str`, optional): Look for participants with this string in name/username. - filter (:obj:`ChannelParticipantsFilter`, optional): - The filter to be used, if you want e.g. only admins. See - https://lonamiwebs.github.io/Telethon/types/channel_participants_filter.html. - Note that you might not have permissions for some filter. - This has no effect for normal chats or users. - aggressive (:obj:`bool`, optional): Aggressively looks for all participants in the chat in order to get more than 10,000 members (a hard limit @@ -1180,32 +1180,16 @@ class TelegramClient(TelegramBareClient): participants on groups with 100,000 members. This has no effect for groups or channels with less than - 10,000 members, or if a ``filter`` is given. + 10,000 members. _total_box (:obj:`_Box`, optional): A _Box instance to pass the total parameter by reference. - Yields: - The ``User`` objects returned by ``GetParticipantsRequest`` - with an additional ``.participant`` attribute which is the - matched ``ChannelParticipant`` type for channels/megagroups - or ``ChatParticipants`` for normal chats. + Returns: + A list of participants with an additional .total variable on the + list indicating the total amount of members in this group/channel. """ - if isinstance(filter, type): - filter = filter() - entity = self.get_input_entity(entity) - if search and (filter or not isinstance(entity, InputPeerChannel)): - # We need to 'search' ourselves unless we have a PeerChannel - search = search.lower() - - def filter_entity(ent): - return search in utils.get_display_name(ent).lower() or\ - search in (getattr(ent, 'username', '') or None).lower() - else: - def filter_entity(ent): - return True - limit = float('inf') if limit is None else int(limit) if isinstance(entity, InputPeerChannel): total = self(GetFullChannelRequest( @@ -1218,7 +1202,7 @@ class TelegramClient(TelegramBareClient): return seen = set() - if total > 10000 and aggressive and not filter: + if total > 10000 and aggressive: requests = [GetParticipantsRequest( channel=entity, filter=ChannelParticipantsSearch(search + chr(x)), @@ -1229,7 +1213,7 @@ class TelegramClient(TelegramBareClient): else: requests = [GetParticipantsRequest( channel=entity, - filter=filter or ChannelParticipantsSearch(search), + filter=ChannelParticipantsSearch(search), offset=0, limit=200, hash=0 @@ -1255,47 +1239,31 @@ class TelegramClient(TelegramBareClient): if not participants.users: requests.pop(i) else: - requests[i].offset += len(participants.participants) - users = {user.id: user for user in participants.users} - for participant in participants.participants: - user = users[participant.user_id] - if not filter_entity(user) or user.id in seen: - continue - - seen.add(participant.user_id) - user = users[participant.user_id] - user.participant = participant - yield user - if len(seen) >= limit: - return + requests[i].offset += len(participants.users) + for user in participants.users: + if user.id not in seen: + seen.add(user.id) + yield user + if len(seen) >= limit: + return elif isinstance(entity, InputPeerChat): - # TODO We *could* apply the `filter` here ourselves - full = self(GetFullChatRequest(entity.chat_id)) + users = self(GetFullChatRequest(entity.chat_id)).users if _total_box: - _total_box.x = len(full.full_chat.participants.participants) + _total_box.x = len(users) have = 0 - users = {user.id: user for user in full.users} - for participant in full.full_chat.participants.participants: - user = users[participant.user_id] - if not filter_entity(user): - continue + for user in users: have += 1 if have > limit: break else: - user = users[participant.user_id] - user.participant = participant yield user else: if _total_box: _total_box.x = 1 if limit != 0: - user = self.get_entity(entity) - if filter_entity(user): - user.participant = None - yield user + yield self.get_entity(entity) def get_participants(self, *args, **kwargs): """ @@ -1304,9 +1272,9 @@ class TelegramClient(TelegramBareClient): """ total_box = _Box(0) kwargs['_total_box'] = total_box - participants = UserList(self.iter_participants(*args, **kwargs)) - participants.total = total_box.x - return participants + dialogs = UserList(self.iter_participants(*args, **kwargs)) + dialogs.total = total_box.x + return dialogs # endregion @@ -2340,8 +2308,7 @@ class TelegramClient(TelegramBareClient): return self.get_me() result = self(ResolveUsernameRequest(username)) for entity in itertools.chain(result.users, result.chats): - if getattr(entity, 'username', None) or ''\ - .lower() == username: + if entity.username.lower() == username: return entity try: # Nobody with this username, maybe it's an exact name/title