upload ml

This commit is contained in:
илья 2023-09-10 08:43:58 +03:00
commit e548bd4ad7
9 changed files with 1495 additions and 0 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

10
README.md Normal file
View File

@ -0,0 +1,10 @@
### Установить зависимости
`pip3 install -r req.txt`
### Локации документов для обучения
- `catboost-train.ipynb` - файл для обучения кэтбуста
- `bert_train.ipynb` - файл для обучения берта
- `nearest-search-train.ipynb` - файл для обучения поиска близжайших соседей
- `tfidf-train.ipynb` - файл для обучения tf-idf + random forest
### Инференс модели
`uvicorn inference:app --reload --workers 1`
### У мля слишком много больших файлов поэтому мы выложили код с весами моделей на гугл диск https://drive.google.com/drive/folders/1hnWKpZjtQLBbzAE9YsUW_4x-IEb3mFvg?usp=sharing

BIN
annoy_labels.pickle Normal file

Binary file not shown.

287
bert_train.ipynb Normal file
View File

@ -0,0 +1,287 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Requirement already satisfied: transformers in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (4.32.1)\n",
"Requirement already satisfied: datasets in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.14.4)\n",
"Requirement already satisfied: pandas in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.1.0)\n",
"Requirement already satisfied: evaluate in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (0.4.0)\n",
"Requirement already satisfied: numpy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.25.2)\n",
"Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (3.12.3)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.16.4)\n",
"Requirement already satisfied: packaging>=20.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (2023.8.8)\n",
"Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.3.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (13.0.0)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: xxhash in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (3.8.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: responses<0.19 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (3.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (2023.7.22)\n"
]
}
],
"source": [
"!pip3 install transformers datasets pandas evaluate numpy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"label2id = {\n",
" 'AAA(RU)': 0,\n",
" 'AA(RU)': 1, \n",
" 'A+(RU)': 2,\n",
" 'A(RU)': 3,\n",
" 'A-(RU)': 4,\n",
" 'BBB+(RU)': 5,\n",
" 'BBB(RU)': 6, \n",
" 'AA+(RU)': 7,\n",
" 'BBB-(RU)': 8,\n",
" 'AA-(RU)': 9,\n",
" 'BB+(RU)': 10, \n",
" 'BB-(RU)': 11, \n",
" 'B+(RU)': 12,\n",
" 'BB(RU)': 13, \n",
" 'B(RU)': 14,\n",
" 'B-(RU)': 15, \n",
" 'C(RU)': 16\n",
"}\n",
"id2label = {0: 'AAA(RU)',\n",
" 1: 'AA(RU)',\n",
" 2: 'A+(RU)',\n",
" 3: 'A(RU)',\n",
" 4: 'A-(RU)',\n",
" 5: 'BBB+(RU)',\n",
" 6: 'BBB(RU)',\n",
" 7: 'AA+(RU)',\n",
" 8: 'BBB-(RU)',\n",
" 9: 'AA-(RU)',\n",
" 10: 'BB+(RU)',\n",
" 11: 'BB-(RU)',\n",
" 12: 'B+(RU)',\n",
" 13: 'BB(RU)',\n",
" 14: 'B(RU)',\n",
" 15: 'B-(RU)',\n",
" 16: 'C(RU)'}"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open('dss.pickle', 'rb') as file:\n",
" data = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'text': 'Кредитный рейтинг АКБ Энергобанк ПАО далее Энергобанк Банк обусловлен удовлетворительным бизнеспрофилем в сочетании с сильной достаточностью капитала критическим рискпрофилем и удовлетворительной оценкой ликвидности и фондирования учитывающей концентрацию обязательств Банка на средствах крупнейших кредиторовОсновная деятельность Энергобанка сконцентрирована в Республике Татарстан далее РТ где он занимает устойчивые рыночные позиции занял шестое место по размеру активов на На российском банковском рынке Банк имеет невысокую долю на он занимал е место по величине собственного капитала и е место по размеру активовКлючевыми направлениями деятельности Банка являются кредитование предприятий агропромышленного комплекса строительства и торговли а также залоговое розничное кредитование Контролирующими собственниками Банка являются ИН Хайруллин и АН Хайруллин владеющие около акций через АО Эдельвейс КорпорейшнОценка бизнеспрофиля Банка отражает его относительно невысокую долю на российском рынке банковских услуг и выраженную региональную направленность его деятельности несмотря на планы по открытию отделений за пределами РТОперационный доход Банка характеризуется низким хотя и возрастающим уровнем диверсификации по итогам года значение индекса ХерфиндаляХиршмана составило и свидетельствовало о повышенной концентрации на кредитовании корпоративного сектора около операционного дохода Качество управления Банком оценивается АКРА как удовлетворительное и соответствующее среднему уровню в российском банковском секторе в целом Организационная структура Энергобанка соответствует масштабам и особенностям его бизнеса Структура собственности Банка прозрачна при этом отмечается связанность его операций с аффилированными с собственниками Банка компаниямиСтратегия Банка на период до конца года предполагает планомерное наращивание размера активов за',\n",
" 'label': 11}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[0]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
" \"cointegrated/rubert-tiny\", num_labels=len(id2label.keys()), id2label=id2label, label2id=label2id\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"cointegrated/rubert-tiny\")\n",
"def token(text):\n",
" return tokenizer(text['text'], padding=True, truncation=True, max_length=512, return_tensors='pt')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 11500/11500 [00:04<00:00, 2842.10 examples/s]\n",
"Map: 100%|██████████| 1845/1845 [00:00<00:00, 3125.02 examples/s]\n"
]
}
],
"source": [
"from datasets import Dataset\n",
"import pandas as pd\n",
"from random import shuffle\n",
"shuffle(data)\n",
"train = data[:11500]\n",
"test = data[11500:]\n",
"train = Dataset.from_pandas(pd.DataFrame(data=train))\n",
"test = Dataset.from_pandas(pd.DataFrame(data=test))\n",
"tokenized_train = train.map(token, batched=True)\n",
"tokenized_test = test.map(token, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import evaluate\n",
"import numpy as np\n",
"\n",
"f1 = evaluate.load(\"f1\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" predictions = np.argmax(predictions, axis=1)\n",
" return f1.compute(predictions=predictions, references=labels, average='macro')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found safetensors installation, but --save_safetensors=False. Safetensors should be a preferred weights saving format due to security and performance reasons. If your model cannot be saved by safetensors please feel free to open an issue at https://github.com/huggingface/safetensors!\n",
"PyTorch: setting up devices\n",
"The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
]
}
],
"source": [
"from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
"\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
"training_args = TrainingArguments(\n",
" output_dir=\"akra_model\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=16,\n",
" per_device_eval_batch_size=16,\n",
" num_train_epochs=10,\n",
" weight_decay=0.01,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
")\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_train,\n",
" eval_dataset=tokenized_test,\n",
" tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
" compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

161
catboost-train.ipynb Normal file
View File

@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: catboost in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.2.1)\n",
"Requirement already satisfied: graphviz in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (0.20.1)\n",
"Requirement already satisfied: matplotlib in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (3.7.2)\n",
"Requirement already satisfied: numpy>=1.16.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (1.25.2)\n",
"Requirement already satisfied: pandas>=0.24 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (2.1.0)\n",
"Requirement already satisfied: scipy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (1.11.2)\n",
"Requirement already satisfied: plotly in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (5.16.1)\n",
"Requirement already satisfied: six in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (1.16.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas>=0.24->catboost) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas>=0.24->catboost) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas>=0.24->catboost) (2023.3)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (0.11.0)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (4.42.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (23.1)\n",
"Requirement already satisfied: pillow>=6.2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (10.0.0)\n",
"Requirement already satisfied: pyparsing<3.1,>=2.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (3.0.9)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from plotly->catboost) (8.2.3)\n"
]
}
],
"source": [
"!pip3 install catboost"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"with open('cb.pickle', 'rb') as file:\n",
" cb_dataset = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from catboost import CatBoostClassifier"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model = CatBoostClassifier(iterations=50000,\n",
" learning_rate=1e-2,\n",
" depth=8)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"label2id = {\n",
" 'AAA(RU)': 0,\n",
" 'AA(RU)': 1, \n",
" 'A+(RU)': 2,\n",
" 'A(RU)': 3,\n",
" 'A-(RU)': 4,\n",
" 'BBB+(RU)': 5,\n",
" 'BBB(RU)': 6, \n",
" 'AA+(RU)': 7,\n",
" 'BBB-(RU)': 8,\n",
" 'AA-(RU)': 9,\n",
" 'BB+(RU)': 10, \n",
" 'BB-(RU)': 11, \n",
" 'B+(RU)': 12,\n",
" 'BB(RU)': 13, \n",
" 'B(RU)': 14,\n",
" 'B-(RU)': 15, \n",
" 'C(RU)': 16\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train = []\n",
"train_labels = []\n",
"for i in cb_dataset[0:6500]:\n",
" train.append([label2id[i['outs'][0]['answer']], i['outs'][0]['metric'], i['outs'][1]['metric'], label2id[i['outs'][1]['answer']], i['outs'][2]['metric'], label2id[i['outs'][2]['answer']]])\n",
" if not isinstance(i['label'], int):\n",
" train_labels.append(label2id[i['label'] + '(RU)'])\n",
" else:\n",
" train_labels.append(i['label'])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"test = []\n",
"test_labels = []\n",
"for i in cb_dataset[6500:]:\n",
" test.append([label2id[i['outs'][0]['answer']], i['outs'][0]['metric'], i['outs'][1]['metric'], label2id[i['outs'][1]['answer']], i['outs'][2]['metric'], label2id[i['outs'][2]['answer']]])\n",
" if not isinstance(i['label'], int):\n",
" test_labels.append(label2id[i['label'] + '(RU)'])\n",
" else:\n",
" test_labels.append(i['label'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.fit(train, train_labels)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

529
inference.py Normal file
View File

@ -0,0 +1,529 @@
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, BertConfig
from captum.attr import LayerIntegratedGradients
import re
import torch
import numpy as np
from collections import Counter
from fastapi import FastAPI
from pydantic import BaseModel
import pickle
from matplotlib.colors import LinearSegmentedColormap
from catboost import CatBoostClassifier
from pymystem3 import Mystem
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from matplotlib.colors import LinearSegmentedColormap
from sentence_transformers import SentenceTransformer
sentence_model = SentenceTransformer('sentence-transformers/LaBSE')
catboost = CatBoostClassifier().load_model('catboost')
def get_embs(text):
embeddings = sentence_model.encode(text)
return embeddings
cmap = LinearSegmentedColormap.from_list('rg',["w", "g"], N=512)
mstm = Mystem()
with open('vectorizer.pickle', 'rb') as file:
model_tfidf = pickle.load(file)
with open('tree.pickle', 'rb') as file:
cls = pickle.load(file)
def resolve_text(tokens, text):
words = text.split()
tokens_values = list(map(lambda tok: tok[0], tokens))
tokens_metrics = list(map(lambda tok: tok[1], tokens))
resolved = []
for i, word in enumerate(words):
try:
if mstm.lemmatize(word)[0] in tokens_values:
try:
value = tokens_metrics[tokens_values.index(mstm.lemmatize(word)[0])]
#color = from_abs_to_rgb(min(tokens_metrics), max(tokens_metrics), value)
resolved.append(f'<span data-value="{(value - min(tokens_metrics))/ max(tokens_metrics)}">{word}</span>')
except:
resolved.append(word)
else:
resolved.append(word)
except:
resolved.append(word)
return ' '.join(resolved)
def process_classify(text):
if not len(text.replace(' ', '')): return {'ans': 0, 'text': ''}
try:
normalized = ''.join(mstm.lemmatize(text)[:-1])
except: return {'ans': 0, 'text': ''}
tf_idfed = model_tfidf.transform(np.array([normalized]))[0]
ans = cls.predict(tf_idfed)[0]
return {'ans': ans, 'text': ""}
def process_embedding(text):
if not len(text.replace(' ', '')): return {'ans': 0, 'text': ''}
try:
normalized = ''.join(mstm.lemmatize(text)[:-1])
except: return {'ans': 0, 'text': ''}
tf_idfed = model_tfidf.transform(np.array([normalized]))[0]
values = []
for i in range(5000):
values.append(tf_idfed.todense()[0, i])
important_tokens = []
for i, val in enumerate(values):
if val > (np.min(values) + np.max(values)) / 3:
important_tokens.append((val, i))
tokens = model_tfidf.get_feature_names_out()
tokens = list(map(lambda x: (tokens[x[1]], x[0]), reversed(sorted(important_tokens))))
ans = cls.predict(tf_idfed)[0]
text = resolve_text(tokens, text)
return {'ans': ans, 'text': text}
cmap = LinearSegmentedColormap.from_list('rg',["w", "g"], N=512)
label2id = {
'AAA(RU)': 0,
'AA(RU)': 1,
'A+(RU)': 2,
'A(RU)': 3,
'A-(RU)': 4,
'BBB+(RU)': 5,
'BBB(RU)': 6,
'AA+(RU)': 7,
'BBB-(RU)': 8,
'AA-(RU)': 9,
'BB+(RU)': 10,
'BB-(RU)': 11,
'B+(RU)': 12,
'BB(RU)': 13,
'B(RU)': 14,
'B-(RU)': 15,
'C(RU)': 16
}
id2label = {0: 'AAA(RU)',
1: 'AA(RU)',
2: 'A+(RU)',
3: 'A(RU)',
4: 'A-(RU)',
5: 'BBB+(RU)',
6: 'BBB(RU)',
7: 'AA+(RU)',
8: 'BBB-(RU)',
9: 'AA-(RU)',
10: 'BB+(RU)',
11: 'BB-(RU)',
12: 'B+(RU)',
13: 'BB(RU)',
14: 'B(RU)',
15: 'B-(RU)',
16: 'C(RU)'}
cmap = LinearSegmentedColormap.from_list('rg',["w", "g"], N=512)
from math import inf
from annoy import AnnoyIndex
import numpy as np
import pickle
def get_distance(emb1, emb2):
emb2 /= np.sum(emb2**2)
emb1 /= np.sum(emb1**2)
return 1 / abs(np.dot(emb2-emb1, emb1-emb2))
with open('new_embeddings.pickle', 'rb') as file:
new_embeddings = pickle.load(file)
with open('annoy_labels.pickle', 'rb') as file:
labels = pickle.load(file)
with open('n_labels.pickle', 'rb') as file:
n_labels = pickle.load(file)
index = AnnoyIndex(768, 'angular')
index.load('nearest.annoy')
def get_nearest_value(embeddings):
items = list(map(lambda x: (
labels[x],
get_distance(embeddings, new_embeddings[x]),
list(n_labels)[x]
),
index.get_nns_by_vector(embeddings, 20)
))
weights = np.array([0 for _ in range(17)])
refs = [[] for _ in range(17)]
s = 0
for item in items:
if item[1] == inf:
return id2label[item[0]], 100, [item[2]]
s += item[1]
weights[item[0]] += item[1]
refs[item[0]].append(item[2])
return id2label[np.argmax(weights)], (weights[np.argmax(weights)] / s) * 100, refs[np.argmax(weights)]
def to_rgb(vals):
return f'rgb({int(vals[0]*255)}, {int(vals[1]*255)}, {int(vals[2]*255)})'
def from_abs_to_rgb(min, max, value):
return to_rgb(cmap((value - min)/ max))
def get_nns_tokens(encoding, attrs, predicted_id):
current_array = map(
lambda x: (tokenizer.convert_ids_to_tokens(encoding['input_ids'][0][x[0]-5:x[0]+5]), x[1]),
list(
reversed(
sorted(
enumerate(
attrs[0][predicted_id].numpy()
),
key=lambda x: x[1]
)
)
)[0:10]
)
return list(current_array)
def get_description_interpreting(attrs, predicted_id):
attrs = attrs.detach().numpy()
positive_weights = attrs[0][predicted_id]
negative_weights = [0 for _ in range(len(positive_weights))]
for i in range(len(attrs[0])):
if i == predicted_id: continue
for j in range(len(attrs[0][i])):
negative_weights[j] += attrs[0][i][j]
for i in range(len(negative_weights)):
negative_weights[i] /= len(attrs[0]) - 1
return {
'positive_weights': (
positive_weights,
{
'min': np.min(positive_weights),
'max': np.max(positive_weights)
}
),
'negative_weights': (
negative_weights,
{
'min': min(negative_weights),
'max': max(negative_weights)
}
)
}
def transform_token_ids(func_data, token_ids, word):
tokens = list(map(lambda x: tokenizer.convert_ids_to_tokens([x])[0].replace('##', ''), token({'text': clean(word)})['input_ids'][0]))
weights = [func_data['positive_weights'][0][i] for i in token_ids]
wts = []
for i in range(len(weights)):
if weights[i] > 0:
#color = from_abs_to_rgb(func_data['positive_weights'][1]['min'], func_data['positive_weights'][1]['max'], weights[i])
mn = max(func_data['positive_weights'][1]['min'], 0)
mx = func_data['positive_weights'][1]['max']
wts.append((weights[i] - mn)/ mx)
#word = word.lower().replace(tokens[i], f'<span data-value="{(weights[i] - mn)/ mx}">{tokens[i]}</span>')
try:
if sum(wts) / len(wts) >= 0.2:
return f'<span data-value={sum(wts) / len(wts)}>{word}</span>'
except: pass
return word
def build_text(tokens, func_data, current_text):
splitted_text = current_text.split()
splitted_text_iterator = 0
current_word = ''
current_word_ids = []
for i, token in enumerate(tokens):
decoded = tokenizer.convert_ids_to_tokens([token])[0]
if decoded == '[CLS]': continue
if not len(current_word):
current_word = decoded
current_word_ids.append(i)
elif decoded.startswith('##'):
current_word += decoded[2:]
current_word_ids.append(i)
else:
while clean(splitted_text[splitted_text_iterator]) != current_word:
splitted_text_iterator += 1
current_word = decoded
splitted_text[splitted_text_iterator] = transform_token_ids(func_data, current_word_ids, splitted_text[splitted_text_iterator])
current_word_ids = []
return ' '.join(splitted_text)
def squad_pos_forward_func(inputs, token_type_ids=None, attention_mask=None, position=0):
pred = predict(inputs.to(torch.long), token_type_ids.to(torch.long), attention_mask.to(torch.long))
pred = pred[position]
return pred.max(1).values
def predict_press_release(input_ids, token_type_ids, attention_mask):
encoding = {
'input_ids': input_ids.to(model.device),
'token_type_ids': token_type_ids.to(model.device),
'attention_mask': attention_mask.to(model.device)
}
outputs = model(**encoding)
return outputs
def clean(text):
text = re.sub('[^а-яё ]', ' ', str(text).lower())
text = re.sub(r" +", " ", text).strip()
return text
def get_description_interpreting(attrs):
positive_weights = attrs
return {
'positive_weights': (
positive_weights,
{
'min': np.min(positive_weights),
'max': np.max(positive_weights)
}
),
}
def predict(input_ids, token_type_ids, attention_mask):
encoding = {
'input_ids': input_ids.to(model.device),
'token_type_ids': token_type_ids.to(model.device),
'attention_mask': attention_mask.to(model.device)
}
outputs = model(**encoding)
return outputs
def batch_tokenize(text):
splitted_text = text.split()
current_batch = splitted_text[0]
batches = []
for word in splitted_text[1:]:
if len(tokenizer(current_batch + ' ' + word)['input_ids']) < 512:
current_batch += ' ' + word
else:
batches.append({
'text': current_batch
})
current_batch = word
return batches + [{'text': current_batch}]
def token(text):
return tokenizer(text['text'], padding=True, truncation=True, max_length=512, return_tensors='pt')
def tfidf_classify(data):
if not len(data.data): return ''
data = list(map(lambda x: x['text'], batch_tokenize(clean(data.data))))
predicted_labels = []
predicted_text = ""
for item in data:
predicted_labels.append(process_classify(item)['ans'])
ans = Counter(predicted_labels).most_common()[0][0]
score = len(list(filter(lambda x: x == ans, predicted_labels))) / len(predicted_labels)
ans = id2label[ans]
return {'answer': ans, 'text': predicted_text, 'metric': score, 'extendingLabels': list(map(lambda x: id2label[x], predicted_labels))}
def tfidf_embeddings(data):
if not len(data.data): return ''
data = list(map(lambda x: x['text'], batch_tokenize(clean(data.data))))
predicted_labels = []
predicted_text = ""
for item in data:
ans = process_embedding(item)
predicted_labels.append(ans['ans'])
predicted_text += ans['text'] + ' '
ans = Counter(predicted_labels).most_common()[0][0]
print(ans, predicted_text)
return {'answer': id2label[ans], 'text': predicted_text}
def bert_classify(data):
data = clean(data)
predicted = []
text = ''
batched = batch_tokenize(data)
for b in batched:
print(len(predicted))
embs = token(b)
answer = predict_press_release(
embs['input_ids'], embs['token_type_ids'], embs['attention_mask']
).logits[0]
answer = torch.softmax(answer, dim=-1).detach().numpy()
answer_score = np.max(answer)
predicted.append(
[id2label[np.argmax(answer)],
float(answer_score)]
)
ans = {'AA(RU)': [0]}
for i in predicted:
if i[0] not in ans.keys():
ans.update({i[0]: [i[1]]})
else:
ans[i[0]].append(i[1])
selected = 'AA(RU)'
score = 0
for candidate in ans.keys():
if sum(ans[candidate]) / len(ans[candidate]) > score:
score = sum(ans[candidate]) / len(ans[candidate])
selected = candidate
elif sum(ans[candidate]) / len(ans[candidate]) == score and len(ans[candidate]) > len(ans):
selected = candidate
return {
'answer': selected,
'text': text,
'longAnswer': predicted,
'metric': score
}
def bert_embeddings(data):
data = clean(data)
predicted = []
text = ''
batched = batch_tokenize(data)
for b in batched:
embs = token(b)
predicted.append(np.argmax(predict_press_release(embs['input_ids'], embs['token_type_ids'], embs['attention_mask']).logits.detach().numpy()[0]))
attrs = lig.attribute(embs['input_ids'], additional_forward_args=(embs['attention_mask'], embs['token_type_ids'], 0))
attrs = np.array(list(map(lambda x: x.sum(), attrs[0])))
descr = get_description_interpreting(attrs)
text += build_text(embs['input_ids'][0], descr, b['text']) + ' '
return {'answer': id2label[Counter(predicted).most_common()[0][0]], 'text': text}
config = BertConfig.from_json_file("./akra_model/checkpoint/config.json")
model = AutoModelForSequenceClassification.from_pretrained(
"./akra_model/checkpoint", config=config
)
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny")
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)
app = FastAPI()
class Predict(BaseModel):
data: str
class ListPredict(BaseModel):
data: list
@app.post('/predict')
def predict_(data: Predict):
return bert_classify(data)
@app.post('/bert/process')
def predict_f(data: Predict):
return bert_classify(data)
@app.get('/interpret')
def interpret():
pass
@app.post('/tfidf/process')
def tfidf_res(data: Predict):
return tfidf_classify(data)
@app.post('/tfidf/batch')
def tfidf_batch(data: ListPredict):
res = []
for item in data.data:
res.append(tfidf_classify(Predict(data=item)))
return res
@app.post('/bert/batch')
def bert_batch(data: ListPredict):
res = []
for item in data:
res.append(bert_classify({'data': item}))
return res
@app.post('/bert/describe')
def bert_describe(data: Predict):
return bert_embeddings(data)
@app.post('/tfidf/describe')
def tfidf_describe(data: Predict):
return tfidf_embeddings(data)
def get_nearest_service(data: Predict):
data = clean(data.data)
batched = batch_tokenize(data)
res = []
scores = {}
for key in id2label.values():
scores.update({key: []})
for batch in batched:
features = list(get_nearest_value(get_embs(batch['text'])))
features[0] = features[0]
features[1] /= 100
scores[features[0]].append(features[1] if features[1] < 95 else 100)
res.append(
{
'text': batch['text'],
'features': features
}
)
mx = 0
label = 'AA(RU)'
for key in scores.keys():
try:
if (sum(scores[key]) / len(scores[key])) > mx:
label = key
mx = (sum(scores[key]) / len(scores[key]))
if (sum(scores[key]) / len(scores[key])) == mx:
if len(scores[key]) > len(scores[label]):
label = key
except: pass
return {'detailed': res, 'metric': mx, 'answer': label}
@app.post('/nearest/nearest')
def proccess_text(data: Predict):
return get_nearest_service(data)
@app.post('/catboost')
def catboost_process(data: Predict):
tfidf = tfidf_classify(data)
bert = bert_classify(data)
nearest = get_nearest_service(data)
inputs = [label2id[tfidf['answer']], tfidf['metric'], bert['metric'], label2id[bert['answer']], nearest['metric'], label2id[nearest['answer']]]
catboost_answer = id2label[catboost.predict([inputs])[0][0]]
return {
'bert': bert,
'tfidf': tfidf,
'nearest': nearest,
'total': catboost_answer
}

177
nearest-search-train.ipynb Normal file
View File

@ -0,0 +1,177 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Requirement already satisfied: annoy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.17.3)\n",
"Requirement already satisfied: sentence_transformers in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.2.2)\n",
"Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (4.32.1)\n",
"Requirement already satisfied: tqdm in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (4.66.1)\n",
"Requirement already satisfied: torch>=1.6.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (2.0.1)\n",
"Requirement already satisfied: torchvision in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (0.15.2)\n",
"Requirement already satisfied: numpy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (1.25.2)\n",
"Requirement already satisfied: scikit-learn in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (1.3.0)\n",
"Requirement already satisfied: scipy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (1.11.2)\n",
"Requirement already satisfied: nltk in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (3.8.1)\n",
"Requirement already satisfied: sentencepiece in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (0.1.99)\n",
"Requirement already satisfied: huggingface-hub>=0.4.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sentence_transformers) (0.16.4)\n",
"Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (3.12.3)\n",
"Requirement already satisfied: fsspec in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2023.6.0)\n",
"Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2.31.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (4.7.1)\n",
"Requirement already satisfied: packaging>=20.9 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (23.1)\n",
"Requirement already satisfied: sympy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence_transformers) (1.12)\n",
"Requirement already satisfied: networkx in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence_transformers) (3.1)\n",
"Requirement already satisfied: jinja2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torch>=1.6.0->sentence_transformers) (3.1.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.3.3)\n",
"Requirement already satisfied: click in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from nltk->sentence_transformers) (8.1.7)\n",
"Requirement already satisfied: joblib in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from nltk->sentence_transformers) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from scikit-learn->sentence_transformers) (3.2.0)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from torchvision->sentence_transformers) (10.0.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from jinja2->torch>=1.6.0->sentence_transformers) (2.1.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.2.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2023.7.22)\n",
"Requirement already satisfied: mpmath>=0.19 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from sympy->torch>=1.6.0->sentence_transformers) (1.3.0)\n"
]
}
],
"source": [
"!pip3 install annoy sentence_transformers"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"model = SentenceTransformer('sentence-transformers/LaBSE')\n",
"\n",
"def get_embs(text):\n",
" embeddings = model.encode(text)\n",
" return embeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open('n_labels.pickle', 'rb') as file:\n",
" n_labels = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"with open('annoy_labels.pickle', 'rb') as file:\n",
" labels = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embs = []\n",
"for text in n_labels:\n",
" embs.append(get_embs(text))\n",
" if len(embs) % 50 == 0:\n",
" print(len(embs))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from annoy import AnnoyIndex\n",
"index = AnnoyIndex(768, 'angular')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"for i, emb in enumerate(embs):\n",
" index.add_item(i, emb)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.build(20)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

123
req.txt Normal file
View File

@ -0,0 +1,123 @@
accelerate==0.22.0
aiohttp==3.8.5
aiosignal==1.3.1
annotated-types==0.5.0
annoy==1.17.3
anyio==3.7.1
appnope==0.1.3
asttokens==2.2.1
async-timeout==4.0.3
attrs==23.1.0
backcall==0.2.0
beautifulsoup4==4.12.2
captum==0.6.0
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
comm==0.1.4
contourpy==1.1.0
cycler==0.11.0
datasets==2.14.4
debugpy==1.6.7.post1
decorator==5.1.1
dill==0.3.7
et-xmlfile==1.1.0
evaluate==0.4.0
exceptiongroup==1.1.3
executing==1.2.0
fastapi==0.103.1
filelock==3.12.3
fonttools==4.42.1
frozenlist==1.4.0
fsspec==2023.6.0
h11==0.14.0
httptools==0.6.0
huggingface-hub==0.16.4
icu==0.0.1
idna==3.4
ipykernel==6.25.1
ipython==7.34.0
jedi==0.19.0
Jinja2==3.1.2
joblib==1.3.2
jupyter_client==8.3.1
jupyter_core==5.3.1
kiwisolver==1.4.5
MarkupSafe==2.1.3
matplotlib==3.7.2
matplotlib-inline==0.1.6
Morfessor==2.0.6
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
nest-asyncio==1.5.7
networkx==3.1
nltk==3.8.1
numpy==1.25.2
openpyxl==3.1.2
outcome==1.2.0
packaging==23.1
pandas==2.1.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.0.0
platformdirs==3.10.0
polyglot==16.7.4
prompt-toolkit==3.0.39
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==13.0.0
pydantic==2.3.0
pydantic_core==2.6.3
Pygments==2.16.1
pymystem3==0.2.0
pyparsing==3.0.9
PySocks==1.7.1
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3
PyYAML==6.0.1
pyzmq==25.1.1
regex==2023.8.8
requests==2.31.0
responses==0.18.0
safetensors==0.3.3
scikit-learn==1.3.0
scipy==1.11.2
seaborn==0.12.2
selenium==4.12.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
stack-data==0.6.2
starlette==0.27.0
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.13.3
torch==2.0.1
torchvision==0.15.2
tornado==6.3.3
tqdm==4.66.1
traitlets==5.9.0
transformers==4.32.1
transformers-interpret==0.10.0
transliterate==1.10.2
trio==0.22.2
trio-websocket==0.10.3
typing_extensions==4.7.1
tzdata==2023.3
undetected-chromedriver==3.5.3
urllib3==2.0.4
uvicorn==0.23.2
uvloop==0.17.0
watchfiles==0.20.0
wcwidth==0.2.6
websockets==11.0.3
wsproto==1.2.0
xxhash==3.3.0
yarl==1.9.2

208
tfidf-train.ipynb Normal file
View File

@ -0,0 +1,208 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pymystem3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (0.2.0)\n",
"Requirement already satisfied: pandas in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.1.0)\n",
"Collecting sklearn\n",
" Downloading sklearn-0.0.post9.tar.gz (3.6 kB)\n",
" Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25hRequirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pymystem3) (2.31.0)\n",
"Requirement already satisfied: numpy>=1.23.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (1.25.2)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: six>=1.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (3.2.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (2023.7.22)\n",
"Building wheels for collected packages: sklearn\n",
" Building wheel for sklearn (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for sklearn: filename=sklearn-0.0.post9-py3-none-any.whl size=2952 sha256=de085da5188e0680130af47d37bf6a7803a4dbec121af8adf834ac3d03747231\n",
" Stored in directory: /Users/ilya/Library/Caches/pip/wheels/ef/63/d1/f1671e1e93b7ef4d35df483f9b2485e6dd21941da9a92296fb\n",
"Successfully built sklearn\n",
"Installing collected packages: sklearn\n",
"Successfully installed sklearn-0.0.post9\n"
]
}
],
"source": [
"!pip3 install pymystem3 pandas sklearn"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open('dss.pickle', 'rb') as file:\n",
" data = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import pandas as pd\n",
"from pymystem3 import Mystem\n",
"\n",
"def clean(text):\n",
" text = re.sub('[^а-яё ]', ' ', str(text).lower())\n",
" text = re.sub(r\" +\", \" \", text).strip()\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"texts = list(map(lambda x: clean(x['text']), data))\n",
"labels = list(map(lambda x: x['label'], data))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"mstm = Mystem()\n",
"normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.DataFrame()\n",
"df['text'] = texts\n",
"df['norm'] = normalized\n",
"df['label'] = labels"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"train, test = train_test_split(df, test_size=0.1, random_state=42)\n",
"valid, test = train_test_split(test, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"model_tfidf = TfidfVectorizer(max_features=5000)\n",
"\n",
"train_tfidf = model_tfidf.fit_transform(train['norm'].values)\n",
"valid_tfidf = model_tfidf.transform(valid['norm'].values)\n",
"test_tfidf = model_tfidf.transform(test['norm'].values)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestClassifier</label><div class=\"sk-toggleable__content\"><pre>RandomForestClassifier(random_state=42)</pre></div></div></div></div></div>"
],
"text/plain": [
"RandomForestClassifier(random_state=42)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"cls = RandomForestClassifier(random_state=42)\n",
"cls.fit(train_tfidf, train['label'].values)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.629399514876379"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import f1_score\n",
"predictions = cls.predict(test_tfidf)\n",
"f1_score(predictions, test['label'].values, average='weighted')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}