ml/nearest-search-train.ipynb
2023-09-10 08:43:58 +03:00

178 lines
9.5 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Requirement already satisfied: annoy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.17.3)\n",
"Requirement already satisfied: sentence_transformers in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.2.2)\n",
"Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (4.32.1)\n",
"Requirement already satisfied: tqdm in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (4.66.1)\n",
"Requirement already satisfied: torch>=1.6.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (2.0.1)\n",
"Requirement already satisfied: torchvision in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (0.15.2)\n",
"Requirement already satisfied: numpy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (1.25.2)\n",
"Requirement already satisfied: scikit-learn in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (1.3.0)\n",
"Requirement already satisfied: scipy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (1.11.2)\n",
"Requirement already satisfied: nltk in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (3.8.1)\n",
"Requirement already satisfied: sentencepiece in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (0.1.99)\n",
"Requirement already satisfied: huggingface-hub>=0.4.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (0.16.4)\n",
"Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (3.12.3)\n",
"Requirement already satisfied: fsspec in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2023.6.0)\n",
"Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2.31.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (4.7.1)\n",
"Requirement already satisfied: packaging>=20.9 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (23.1)\n",
"Requirement already satisfied: sympy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence_transformers) (1.12)\n",
"Requirement already satisfied: networkx in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence_transformers) (3.1)\n",
"Requirement already satisfied: jinja2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence_transformers) (3.1.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.3.3)\n",
"Requirement already satisfied: click in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from nltk->sentence_transformers) (8.1.7)\n",
"Requirement already satisfied: joblib in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from nltk->sentence_transformers) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from scikit-learn->sentence_transformers) (3.2.0)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torchvision->sentence_transformers) (10.0.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from jinja2->torch>=1.6.0->sentence_transformers) (2.1.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.2.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2023.7.22)\n",
"Requirement already satisfied: mpmath>=0.19 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sympy->torch>=1.6.0->sentence_transformers) (1.3.0)\n"
]
}
],
"source": [
"# Use the %pip magic so the packages are installed into the environment of the running kernel;\n",
"# `!pip3` runs whichever pip3 is first on PATH, which may belong to a different interpreter.\n",
"# Versions are pinned to the ones recorded in this notebook's saved output.\n",
"%pip install annoy==1.17.3 sentence_transformers==2.2.2"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"# Load the LaBSE checkpoint once; every get_embs() call reuses this single instance.\n",
"model = SentenceTransformer('sentence-transformers/LaBSE')\n",
"\n",
"\n",
"def get_embs(text):\n",
"    \"\"\"Encode `text` with the module-level `model` and return what `model.encode` produces.\"\"\"\n",
"    return model.encode(text)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# NOTE(review): pickle.load can execute arbitrary code, so n_labels.pickle must be a trusted file.\n",
"# The path is relative, i.e. it is resolved against the kernel's working directory.\n",
"with open('n_labels.pickle', 'rb') as labels_file:\n",
"    n_labels = pickle.load(labels_file)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): unpickling can run arbitrary code, so annoy_labels.pickle must be a trusted file.\n",
"# Relies on `pickle` having been imported by an earlier cell.\n",
"with open('annoy_labels.pickle', 'rb') as annoy_labels_file:\n",
"    labels = pickle.load(annoy_labels_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Embed each entry of n_labels in iteration order, printing the running count after every 50th item.\n",
"embs = []\n",
"for count, text in enumerate(n_labels, start=1):\n",
"    embs.append(get_embs(text))\n",
"    if not count % 50:\n",
"        print(count)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from annoy import AnnoyIndex\n",
"\n",
"# Length of the vectors the index will hold.\n",
"# NOTE(review): presumably the output size of the LaBSE model used for the embeddings; confirm.\n",
"embedding_dim = 768\n",
"index = AnnoyIndex(embedding_dim, 'angular')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Register every embedding under its position in `embs`, so Annoy item ids equal list indices.\n",
"for item_id, vector in enumerate(embs):\n",
"    index.add_item(item_id, vector)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build the index with 20 trees; the call stays the last expression so its result is displayed.\n",
"n_trees = 20\n",
"index.build(n_trees)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}