Mirror of https://github.com/magnum-opus-nn-cp/ml.git (synced 2024-11-21 23:46:33 +03:00)

Commit e548bd4ad7 — "upload ml". This commit adds the following files:
README.md — new file (+10 lines)
### Install dependencies

`pip3 install -r req.txt`

### Training notebooks

- `catboost-train.ipynb` - notebook for training CatBoost
- `bert_train.ipynb` - notebook for training BERT
- `nearest-search-train.ipynb` - notebook for training the nearest-neighbor search
- `tfidf-train.ipynb` - notebook for training TF-IDF + random forest

### Model inference

`uvicorn inference:app --reload --workers 1`

### Model weights

The repository has too many large files, so the code together with the model weights is published on Google Drive: https://drive.google.com/drive/folders/1hnWKpZjtQLBbzAE9YsUW_4x-IEb3mFvg?usp=sharing
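Once the server is up, it can be queried over HTTP. A minimal sketch (not part of the repo; it assumes uvicorn's default address `localhost:8000` — the `/predict` endpoint and its `data` field are defined in `inference.py` below):

```python
import requests

# Send a press-release text to the BERT classifier behind /predict.
resp = requests.post(
    "http://localhost:8000/predict",
    json={"data": "Кредитный рейтинг банка обусловлен ..."},
)
print(resp.json()["answer"])  # a rating label such as "BB-(RU)"
```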
annoy_labels.pickle — new binary file (not shown)
bert_train.ipynb — new file (+287 lines)
Cell 1:

```python
!pip3 install transformers datasets pandas evaluate numpy
```

Output: a huggingface/tokenizers fork warning (parallelism disabled), then "Requirement already satisfied" for transformers 4.32.1, datasets 2.14.4, pandas 2.1.0, evaluate 0.4.0, numpy 1.25.2 and their dependencies.
Cell 2:

```python
label2id = {
    'AAA(RU)': 0,
    'AA(RU)': 1,
    'A+(RU)': 2,
    'A(RU)': 3,
    'A-(RU)': 4,
    'BBB+(RU)': 5,
    'BBB(RU)': 6,
    'AA+(RU)': 7,
    'BBB-(RU)': 8,
    'AA-(RU)': 9,
    'BB+(RU)': 10,
    'BB-(RU)': 11,
    'B+(RU)': 12,
    'BB(RU)': 13,
    'B(RU)': 14,
    'B-(RU)': 15,
    'C(RU)': 16
}
id2label = {
    0: 'AAA(RU)',
    1: 'AA(RU)',
    2: 'A+(RU)',
    3: 'A(RU)',
    4: 'A-(RU)',
    5: 'BBB+(RU)',
    6: 'BBB(RU)',
    7: 'AA+(RU)',
    8: 'BBB-(RU)',
    9: 'AA-(RU)',
    10: 'BB+(RU)',
    11: 'BB-(RU)',
    12: 'B+(RU)',
    13: 'BB(RU)',
    14: 'B(RU)',
    15: 'B-(RU)',
    16: 'C(RU)'
}
```
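The reverse mapping can also be derived instead of spelled out, which keeps the two dicts from drifting apart (a one-line alternative, not what the notebook does):

```python
id2label = {v: k for k, v in label2id.items()}
```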
Cell 3:

```python
import pickle
with open('dss.pickle', 'rb') as file:
    data = pickle.load(file)
```

Cell 4:

```python
data[0]
```

Output (one training example: the cleaned text of a rating press release plus its class id):

```python
{'text': 'Кредитный рейтинг АКБ Энергобанк ПАО далее Энергобанк Банк обусловлен удовлетворительным бизнеспрофилем в сочетании с сильной достаточностью капитала критическим рискпрофилем и удовлетворительной оценкой ликвидности и фондирования учитывающей концентрацию обязательств Банка на средствах крупнейших кредиторовОсновная деятельность Энергобанка сконцентрирована в Республике Татарстан далее РТ где он занимает устойчивые рыночные позиции занял шестое место по размеру активов на На российском банковском рынке Банк имеет невысокую долю на он занимал е место по величине собственного капитала и е место по размеру активовКлючевыми направлениями деятельности Банка являются кредитование предприятий агропромышленного комплекса строительства и торговли а также залоговое розничное кредитование Контролирующими собственниками Банка являются ИН Хайруллин и АН Хайруллин владеющие около акций через АО Эдельвейс КорпорейшнОценка бизнеспрофиля Банка отражает его относительно невысокую долю на российском рынке банковских услуг и выраженную региональную направленность его деятельности несмотря на планы по открытию отделений за пределами РТОперационный доход Банка характеризуется низким хотя и возрастающим уровнем диверсификации по итогам года значение индекса ХерфиндаляХиршмана составило и свидетельствовало о повышенной концентрации на кредитовании корпоративного сектора около операционного дохода Качество управления Банком оценивается АКРА как удовлетворительное и соответствующее среднему уровню в российском банковском секторе в целом Организационная структура Энергобанка соответствует масштабам и особенностям его бизнеса Структура собственности Банка прозрачна при этом отмечается связанность его операций с аффилированными с собственниками Банка компаниямиСтратегия Банка на период до конца года предполагает планомерное наращивание размера активов за',
 'label': 11}
```
Cell 5:

```python
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
    "cointegrated/rubert-tiny", num_labels=len(id2label.keys()), id2label=id2label, label2id=label2id
)
```

Output (expected for a fresh classification head): a warning that `classifier.weight` and `classifier.bias` were newly initialized and that the model should be trained on a downstream task before being used for predictions.

Cell 6:

```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny")

def token(text):
    return tokenizer(text['text'], padding=True, truncation=True, max_length=512, return_tensors='pt')
```
Cell 7:

```python
from datasets import Dataset
import pandas as pd
from random import shuffle

shuffle(data)
train = data[:11500]
test = data[11500:]
train = Dataset.from_pandas(pd.DataFrame(data=train))
test = Dataset.from_pandas(pd.DataFrame(data=test))
tokenized_train = train.map(token, batched=True)
tokenized_test = test.map(token, batched=True)
```

Output: tokenization progress bars (11500 train and 1845 test examples mapped).
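Note that `shuffle` plus a fixed slice gives an unstratified split; with 17 fairly imbalanced rating classes, a stratified split is often safer. A sketch of the alternative (an assumption, not what the notebook does):

```python
from sklearn.model_selection import train_test_split

# Keep the class proportions identical in train and test.
train, test = train_test_split(
    data, test_size=1845, stratify=[x['label'] for x in data], random_state=42
)
```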
Cell 8:

```python
import evaluate
import numpy as np

f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return f1.compute(predictions=predictions, references=labels, average='macro')
```
Cell 9:

```python
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
training_args = TrainingArguments(
    output_dir="akra_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
```

Output: transformers housekeeping warnings (safetensors saving disabled, PyTorch device setup, the upcoming `--report_to` default change in v5).
Cell 10:

```python
trainer.train()
```
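After training, the checkpoint has to end up where the inference service expects it. A minimal sketch of that hand-off (not shown in the notebook; `./akra_model/checkpoint` is the path `inference.py` loads from):

```python
# Report macro-F1 on the held-out split, then export the best model.
print(trainer.evaluate())
trainer.save_model("akra_model/checkpoint")
tokenizer.save_pretrained("akra_model/checkpoint")
```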
(Notebook metadata: Python 3 kernel, Python 3.11.5, nbformat 4.)
catboost-train.ipynb — new file (+161 lines)
Cell 1:

```python
!pip3 install catboost
```

Output: "Requirement already satisfied" for catboost 1.2.1 and its dependencies (graphviz, matplotlib, numpy, pandas, scipy, plotly, six).
Cell 2:

```python
import pickle

with open('cb.pickle', 'rb') as file:
    cb_dataset = pickle.load(file)
```

Cell 3:

```python
from catboost import CatBoostClassifier
```
Cell 4:

```python
model = CatBoostClassifier(iterations=50000,
                           learning_rate=1e-2,
                           depth=8)
```
Cell 5:

```python
label2id = {
    'AAA(RU)': 0,
    'AA(RU)': 1,
    'A+(RU)': 2,
    'A(RU)': 3,
    'A-(RU)': 4,
    'BBB+(RU)': 5,
    'BBB(RU)': 6,
    'AA+(RU)': 7,
    'BBB-(RU)': 8,
    'AA-(RU)': 9,
    'BB+(RU)': 10,
    'BB-(RU)': 11,
    'B+(RU)': 12,
    'BB(RU)': 13,
    'B(RU)': 14,
    'B-(RU)': 15,
    'C(RU)': 16
}
```
Cell 6:

```python
# Six meta-features per example: label/score pairs from the three base models,
# in the same positional layout that inference.py assembles for /catboost
# (assuming outs[0..2] are tf-idf, BERT, and nearest-neighbor, in that order).
train = []
train_labels = []
for i in cb_dataset[0:6500]:
    train.append([label2id[i['outs'][0]['answer']], i['outs'][0]['metric'],
                  i['outs'][1]['metric'], label2id[i['outs'][1]['answer']],
                  i['outs'][2]['metric'], label2id[i['outs'][2]['answer']]])
    if not isinstance(i['label'], int):
        train_labels.append(label2id[i['label'] + '(RU)'])
    else:
        train_labels.append(i['label'])
```

Cell 7:

```python
test = []
test_labels = []
for i in cb_dataset[6500:]:
    test.append([label2id[i['outs'][0]['answer']], i['outs'][0]['metric'],
                 i['outs'][1]['metric'], label2id[i['outs'][1]['answer']],
                 i['outs'][2]['metric'], label2id[i['outs'][2]['answer']]])
    if not isinstance(i['label'], int):
        test_labels.append(label2id[i['label'] + '(RU)'])
    else:
        test_labels.append(i['label'])
```
Cell 8:

```python
model.fit(train, train_labels)
```
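The notebook stops at fitting; a quick held-out check would close the loop (a sketch, not in the notebook):

```python
from sklearn.metrics import accuracy_score, f1_score

preds = model.predict(test).ravel()  # CatBoost returns a column vector
print(accuracy_score(test_labels, preds))
print(f1_score(test_labels, preds, average="macro"))
```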
(Notebook metadata: Python 3 kernel, Python 3.11.5, nbformat 4.)
inference.py — new file (+529 lines)

```python
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, BertConfig
from captum.attr import LayerIntegratedGradients
import re
import numpy as np
from collections import Counter
from fastapi import FastAPI
from pydantic import BaseModel
import pickle
from matplotlib.colors import LinearSegmentedColormap
from catboost import CatBoostClassifier

from pymystem3 import Mystem
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier

from sentence_transformers import SentenceTransformer

# Sentence encoder for the nearest-neighbor model.
sentence_model = SentenceTransformer('sentence-transformers/LaBSE')

# CatBoost ensemble that combines the three base models.
catboost = CatBoostClassifier().load_model('catboost')


def get_embs(text):
    embeddings = sentence_model.encode(text)
    return embeddings


cmap = LinearSegmentedColormap.from_list('rg', ["w", "g"], N=512)
mstm = Mystem()

# Fitted TF-IDF vectorizer and the random forest trained on top of it.
with open('vectorizer.pickle', 'rb') as file:
    model_tfidf = pickle.load(file)

with open('tree.pickle', 'rb') as file:
    cls = pickle.load(file)

# Wrap words whose lemma ranks among the important TF-IDF tokens in
# <span data-value="..."> markers so the frontend can highlight them.
def resolve_text(tokens, text):
    words = text.split()
    tokens_values = list(map(lambda tok: tok[0], tokens))
    tokens_metrics = list(map(lambda tok: tok[1], tokens))
    resolved = []
    for i, word in enumerate(words):
        try:
            if mstm.lemmatize(word)[0] in tokens_values:
                try:
                    value = tokens_metrics[tokens_values.index(mstm.lemmatize(word)[0])]
                    resolved.append(f'<span data-value="{(value - min(tokens_metrics)) / max(tokens_metrics)}">{word}</span>')
                except:
                    resolved.append(word)
            else:
                resolved.append(word)
        except:
            resolved.append(word)
    return ' '.join(resolved)


def process_classify(text):
    if not len(text.replace(' ', '')): return {'ans': 0, 'text': ''}
    try:
        normalized = ''.join(mstm.lemmatize(text)[:-1])
    except: return {'ans': 0, 'text': ''}
    tf_idfed = model_tfidf.transform(np.array([normalized]))[0]

    ans = cls.predict(tf_idfed)[0]
    return {'ans': ans, 'text': ""}


# Like process_classify, but also returns the text with important tokens highlighted.
def process_embedding(text):
    if not len(text.replace(' ', '')): return {'ans': 0, 'text': ''}
    try:
        normalized = ''.join(mstm.lemmatize(text)[:-1])
    except: return {'ans': 0, 'text': ''}
    tf_idfed = model_tfidf.transform(np.array([normalized]))[0]
    values = []
    for i in range(5000):
        values.append(tf_idfed.todense()[0, i])

    important_tokens = []
    for i, val in enumerate(values):
        if val > (np.min(values) + np.max(values)) / 3:
            important_tokens.append((val, i))
    tokens = model_tfidf.get_feature_names_out()
    tokens = list(map(lambda x: (tokens[x[1]], x[0]), reversed(sorted(important_tokens))))

    ans = cls.predict(tf_idfed)[0]
    text = resolve_text(tokens, text)
    return {'ans': ans, 'text': text}


label2id = {
    'AAA(RU)': 0,
    'AA(RU)': 1,
    'A+(RU)': 2,
    'A(RU)': 3,
    'A-(RU)': 4,
    'BBB+(RU)': 5,
    'BBB(RU)': 6,
    'AA+(RU)': 7,
    'BBB-(RU)': 8,
    'AA-(RU)': 9,
    'BB+(RU)': 10,
    'BB-(RU)': 11,
    'B+(RU)': 12,
    'BB(RU)': 13,
    'B(RU)': 14,
    'B-(RU)': 15,
    'C(RU)': 16
}
id2label = {v: k for k, v in label2id.items()}  # inverse mapping of label2id

from math import inf
from annoy import AnnoyIndex

def get_distance(emb1, emb2):
    emb2 /= np.sum(emb2**2)
    emb1 /= np.sum(emb1**2)
    # Inverse squared distance between the normalized embeddings,
    # used as a similarity weight (inf when the vectors coincide).
    return 1 / abs(np.dot(emb2 - emb1, emb1 - emb2))


with open('new_embeddings.pickle', 'rb') as file:
    new_embeddings = pickle.load(file)

with open('annoy_labels.pickle', 'rb') as file:
    labels = pickle.load(file)

with open('n_labels.pickle', 'rb') as file:
    n_labels = pickle.load(file)

index = AnnoyIndex(768, 'angular')
index.load('nearest.annoy')


def get_nearest_value(embeddings):
    # (label id, similarity weight, reference text) for the 20 nearest neighbors.
    items = list(map(lambda x: (
        labels[x],
        get_distance(embeddings, new_embeddings[x]),
        list(n_labels)[x]
    ),
        index.get_nns_by_vector(embeddings, 20)
    ))
    weights = np.zeros(17)  # float, so the similarity weights accumulate without truncation
    refs = [[] for _ in range(17)]
    s = 0
    for item in items:
        if item[1] == inf:
            # Exact match: return it with full confidence.
            return id2label[item[0]], 100, [item[2]]
        s += item[1]
        weights[item[0]] += item[1]
        refs[item[0]].append(item[2])
    return id2label[np.argmax(weights)], (weights[np.argmax(weights)] / s) * 100, refs[np.argmax(weights)]

def to_rgb(vals):
    return f'rgb({int(vals[0]*255)}, {int(vals[1]*255)}, {int(vals[2]*255)})'


def from_abs_to_rgb(min, max, value):
    return to_rgb(cmap((value - min) / max))


def get_nns_tokens(encoding, attrs, predicted_id):
    # Ten highest-attribution positions, each with a 10-token context window.
    current_array = map(
        lambda x: (tokenizer.convert_ids_to_tokens(encoding['input_ids'][0][x[0]-5:x[0]+5]), x[1]),
        list(
            reversed(
                sorted(
                    enumerate(attrs[0][predicted_id].numpy()),
                    key=lambda x: x[1]
                )
            )
        )[0:10]
    )
    return list(current_array)


# NOTE: shadowed by the one-argument get_description_interpreting defined
# further down, which is the version the endpoints actually use.
def get_description_interpreting(attrs, predicted_id):
    attrs = attrs.detach().numpy()
    positive_weights = attrs[0][predicted_id]
    negative_weights = [0 for _ in range(len(positive_weights))]
    for i in range(len(attrs[0])):
        if i == predicted_id: continue
        for j in range(len(attrs[0][i])):
            negative_weights[j] += attrs[0][i][j]
    for i in range(len(negative_weights)):
        negative_weights[i] /= len(attrs[0]) - 1

    return {
        'positive_weights': (
            positive_weights,
            {'min': np.min(positive_weights), 'max': np.max(positive_weights)}
        ),
        'negative_weights': (
            negative_weights,
            {'min': min(negative_weights), 'max': max(negative_weights)}
        )
    }

def transform_token_ids(func_data, token_ids, word):
    tokens = list(map(lambda x: tokenizer.convert_ids_to_tokens([x])[0].replace('##', ''),
                      token({'text': clean(word)})['input_ids'][0]))
    weights = [func_data['positive_weights'][0][i] for i in token_ids]
    wts = []
    for i in range(len(weights)):
        if weights[i] > 0:
            mn = max(func_data['positive_weights'][1]['min'], 0)
            mx = func_data['positive_weights'][1]['max']
            wts.append((weights[i] - mn) / mx)
    try:
        # Highlight the word if its sub-tokens' mean normalized attribution is high enough.
        if sum(wts) / len(wts) >= 0.2:
            return f'<span data-value={sum(wts) / len(wts)}>{word}</span>'
    except: pass
    return word


# Re-align WordPiece tokens with the original words and wrap the highly
# attributed words in <span> markers.
def build_text(tokens, func_data, current_text):
    splitted_text = current_text.split()
    splitted_text_iterator = 0
    current_word = ''
    current_word_ids = []
    for i, token in enumerate(tokens):
        decoded = tokenizer.convert_ids_to_tokens([token])[0]
        if decoded == '[CLS]': continue
        if not len(current_word):
            current_word = decoded
            current_word_ids.append(i)
        elif decoded.startswith('##'):
            current_word += decoded[2:]
            current_word_ids.append(i)
        else:
            while clean(splitted_text[splitted_text_iterator]) != current_word:
                splitted_text_iterator += 1
            current_word = decoded
            splitted_text[splitted_text_iterator] = transform_token_ids(func_data, current_word_ids, splitted_text[splitted_text_iterator])
            current_word_ids = []
    return ' '.join(splitted_text)

def squad_pos_forward_func(inputs, token_type_ids=None, attention_mask=None, position=0):
    pred = predict(inputs.to(torch.long), token_type_ids.to(torch.long), attention_mask.to(torch.long))
    pred = pred[position]
    return pred.max(1).values


def predict_press_release(input_ids, token_type_ids, attention_mask):
    encoding = {
        'input_ids': input_ids.to(model.device),
        'token_type_ids': token_type_ids.to(model.device),
        'attention_mask': attention_mask.to(model.device)
    }
    outputs = model(**encoding)
    return outputs


def clean(text):
    # Keep only Cyrillic letters and single spaces.
    text = re.sub('[^а-яё ]', ' ', str(text).lower())
    text = re.sub(r" +", " ", text).strip()
    return text


# One-argument variant; this definition shadows the two-argument one above
# and is the one bert_embeddings uses.
def get_description_interpreting(attrs):
    positive_weights = attrs
    return {
        'positive_weights': (
            positive_weights,
            {'min': np.min(positive_weights), 'max': np.max(positive_weights)}
        ),
    }


def predict(input_ids, token_type_ids, attention_mask):
    encoding = {
        'input_ids': input_ids.to(model.device),
        'token_type_ids': token_type_ids.to(model.device),
        'attention_mask': attention_mask.to(model.device)
    }
    outputs = model(**encoding)
    return outputs


# Greedily pack words into chunks that stay under BERT's 512-token limit.
def batch_tokenize(text):
    splitted_text = text.split()
    current_batch = splitted_text[0]
    batches = []
    for word in splitted_text[1:]:
        if len(tokenizer(current_batch + ' ' + word)['input_ids']) < 512:
            current_batch += ' ' + word
        else:
            batches.append({
                'text': current_batch
            })
            current_batch = word
    return batches + [{'text': current_batch}]
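# For example (illustrative): batch_tokenize("слово " * 2000) returns a list of
# dicts like [{'text': '...'}, ...], each short enough to tokenize to < 512 ids,
# so long press releases are classified chunk by chunk and the votes combined.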

def token(text):
    return tokenizer(text['text'], padding=True, truncation=True, max_length=512, return_tensors='pt')


# TF-IDF + random forest: classify each chunk, then majority-vote.
def tfidf_classify(data):
    if not len(data.data): return ''
    data = list(map(lambda x: x['text'], batch_tokenize(clean(data.data))))
    predicted_labels = []
    predicted_text = ""
    for item in data:
        predicted_labels.append(process_classify(item)['ans'])
    ans = Counter(predicted_labels).most_common()[0][0]
    # Share of chunks that agree with the majority label.
    score = len(list(filter(lambda x: x == ans, predicted_labels))) / len(predicted_labels)
    ans = id2label[ans]
    return {'answer': ans, 'text': predicted_text, 'metric': score,
            'extendingLabels': list(map(lambda x: id2label[x], predicted_labels))}


def tfidf_embeddings(data):
    if not len(data.data): return ''
    data = list(map(lambda x: x['text'], batch_tokenize(clean(data.data))))
    predicted_labels = []
    predicted_text = ""
    for item in data:
        ans = process_embedding(item)
        predicted_labels.append(ans['ans'])
        predicted_text += ans['text'] + ' '
    ans = Counter(predicted_labels).most_common()[0][0]
    print(ans, predicted_text)
    return {'answer': id2label[ans], 'text': predicted_text}

def bert_classify(data):
    data = clean(data.data)  # all call sites pass a Predict model; clean its payload
    predicted = []
    text = ''
    batched = batch_tokenize(data)

    for b in batched:
        print(len(predicted))
        embs = token(b)
        answer = predict_press_release(
            embs['input_ids'], embs['token_type_ids'], embs['attention_mask']
        ).logits[0]
        answer = torch.softmax(answer, dim=-1).detach().numpy()
        answer_score = np.max(answer)
        predicted.append(
            [id2label[np.argmax(answer)],
             float(answer_score)]
        )
    # Pick the label with the highest mean softmax confidence across chunks.
    ans = {'AA(RU)': [0]}
    for i in predicted:
        if i[0] not in ans.keys():
            ans.update({i[0]: [i[1]]})
        else:
            ans[i[0]].append(i[1])
    selected = 'AA(RU)'
    score = 0
    for candidate in ans.keys():
        if sum(ans[candidate]) / len(ans[candidate]) > score:
            score = sum(ans[candidate]) / len(ans[candidate])
            selected = candidate
        elif sum(ans[candidate]) / len(ans[candidate]) == score and len(ans[candidate]) > len(ans):
            selected = candidate
    return {
        'answer': selected,
        'text': text,
        'longAnswer': predicted,
        'metric': score
    }


# Same chunked BERT prediction, plus integrated-gradients attributions that
# get rendered back into the text as <span> highlights.
def bert_embeddings(data):
    data = clean(data.data)
    predicted = []
    text = ''
    batched = batch_tokenize(data)
    for b in batched:
        embs = token(b)
        predicted.append(np.argmax(predict_press_release(embs['input_ids'], embs['token_type_ids'], embs['attention_mask']).logits.detach().numpy()[0]))
        attrs = lig.attribute(embs['input_ids'], additional_forward_args=(embs['attention_mask'], embs['token_type_ids'], 0))
        attrs = np.array(list(map(lambda x: x.sum(), attrs[0])))
        descr = get_description_interpreting(attrs)
        text += build_text(embs['input_ids'][0], descr, b['text']) + ' '
    return {'answer': id2label[Counter(predicted).most_common()[0][0]], 'text': text}


config = BertConfig.from_json_file("./akra_model/checkpoint/config.json")

model = AutoModelForSequenceClassification.from_pretrained(
    "./akra_model/checkpoint", config=config
)
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny")

# Integrated gradients over the embedding layer, for token attributions.
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

app = FastAPI()


class Predict(BaseModel):
    data: str


class ListPredict(BaseModel):
    data: list


@app.post('/predict')
def predict_(data: Predict):
    return bert_classify(data)


@app.post('/bert/process')
def predict_f(data: Predict):
    return bert_classify(data)


@app.get('/interpret')
def interpret():
    pass  # stub endpoint

@app.post('/tfidf/process')
def tfidf_res(data: Predict):
    return tfidf_classify(data)


@app.post('/tfidf/batch')
def tfidf_batch(data: ListPredict):
    res = []
    for item in data.data:
        res.append(tfidf_classify(Predict(data=item)))
    return res


@app.post('/bert/batch')
def bert_batch(data: ListPredict):
    res = []
    for item in data.data:  # iterate the payload list, mirroring tfidf_batch
        res.append(bert_classify(Predict(data=item)))
    return res

@app.post('/bert/describe')
def bert_describe(data: Predict):
    return bert_embeddings(data)


@app.post('/tfidf/describe')
def tfidf_describe(data: Predict):
    return tfidf_embeddings(data)


# Nearest-neighbor model: embed each chunk with LaBSE, vote over the
# Annoy neighbors, and pick the label with the best mean score.
def get_nearest_service(data: Predict):
    data = clean(data.data)
    batched = batch_tokenize(data)
    res = []
    scores = {}
    for key in id2label.values():
        scores.update({key: []})
    for batch in batched:
        features = list(get_nearest_value(get_embs(batch['text'])))
        features[1] /= 100
        scores[features[0]].append(features[1] if features[1] < 95 else 100)
        res.append(
            {
                'text': batch['text'],
                'features': features
            }
        )
    mx = 0
    label = 'AA(RU)'
    for key in scores.keys():
        try:
            if (sum(scores[key]) / len(scores[key])) > mx:
                label = key
                mx = (sum(scores[key]) / len(scores[key]))
            if (sum(scores[key]) / len(scores[key])) == mx:
                if len(scores[key]) > len(scores[label]):
                    label = key
        except: pass
    return {'detailed': res, 'metric': mx, 'answer': label}


@app.post('/nearest/nearest')
def proccess_text(data: Predict):
    return get_nearest_service(data)


@app.post('/catboost')
def catboost_process(data: Predict):
    # Run all three base models, then let CatBoost combine their outputs.
    tfidf = tfidf_classify(data)
    bert = bert_classify(data)
    nearest = get_nearest_service(data)

    inputs = [label2id[tfidf['answer']], tfidf['metric'], bert['metric'],
              label2id[bert['answer']], nearest['metric'], label2id[nearest['answer']]]
    catboost_answer = id2label[catboost.predict([inputs])[0][0]]
    return {
        'bert': bert,
        'tfidf': tfidf,
        'nearest': nearest,
        'total': catboost_answer
    }
```
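The `/catboost` route is the full ensemble: it returns each base model's verdict alongside the combined answer. A usage sketch (assuming the server from the README is running on localhost:8000):

```python
import requests

out = requests.post("http://localhost:8000/catboost",
                    json={"data": "текст пресс-релиза ..."}).json()
print(out["tfidf"]["answer"], out["bert"]["answer"],
      out["nearest"]["answer"], "->", out["total"])
```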
nearest-search-train.ipynb — new file (+177 lines)
Cell 1:

```python
!pip3 install annoy sentence_transformers
```

Output: the tokenizers fork warning, then "Requirement already satisfied" for annoy 1.17.3, sentence_transformers 2.2.2 and their dependencies.
Cell 2:

```python
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('sentence-transformers/LaBSE')

def get_embs(text):
    embeddings = model.encode(text)
    return embeddings
```

Output: a tqdm IProgress warning from Jupyter.
Cell 3:

```python
import pickle
with open('n_labels.pickle', 'rb') as file:
    n_labels = pickle.load(file)
```

Cell 4:

```python
with open('annoy_labels.pickle', 'rb') as file:
    labels = pickle.load(file)
```
Cell 5:

```python
embs = []
for text in n_labels:
    embs.append(get_embs(text))
    if len(embs) % 50 == 0:
        print(len(embs))
```
Cell 6:

```python
from annoy import AnnoyIndex
index = AnnoyIndex(768, 'angular')  # 768 = LaBSE embedding size
```
Cell 7:

```python
for i, emb in enumerate(embs):
    index.add_item(i, emb)
```
Cell 8:

```python
index.build(20)  # 20 trees; returns True on success
```
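The notebook does not show the save step, but `inference.py` loads the index from `nearest.annoy` and the embeddings from `new_embeddings.pickle`, so the artifacts presumably get persisted along these lines (a sketch using those file names):

```python
index.save('nearest.annoy')
with open('new_embeddings.pickle', 'wb') as file:
    pickle.dump(embs, file)
```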
(Notebook metadata: Python 3 kernel, Python 3.11.5, nbformat 4.)
req.txt — new file (+123 lines)

```
accelerate==0.22.0
aiohttp==3.8.5
aiosignal==1.3.1
annotated-types==0.5.0
annoy==1.17.3
anyio==3.7.1
appnope==0.1.3
asttokens==2.2.1
async-timeout==4.0.3
attrs==23.1.0
backcall==0.2.0
beautifulsoup4==4.12.2
captum==0.6.0
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
comm==0.1.4
contourpy==1.1.0
cycler==0.11.0
datasets==2.14.4
debugpy==1.6.7.post1
decorator==5.1.1
dill==0.3.7
et-xmlfile==1.1.0
evaluate==0.4.0
exceptiongroup==1.1.3
executing==1.2.0
fastapi==0.103.1
filelock==3.12.3
fonttools==4.42.1
frozenlist==1.4.0
fsspec==2023.6.0
h11==0.14.0
httptools==0.6.0
huggingface-hub==0.16.4
icu==0.0.1
idna==3.4
ipykernel==6.25.1
ipython==7.34.0
jedi==0.19.0
Jinja2==3.1.2
joblib==1.3.2
jupyter_client==8.3.1
jupyter_core==5.3.1
kiwisolver==1.4.5
MarkupSafe==2.1.3
matplotlib==3.7.2
matplotlib-inline==0.1.6
Morfessor==2.0.6
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
nest-asyncio==1.5.7
networkx==3.1
nltk==3.8.1
numpy==1.25.2
openpyxl==3.1.2
outcome==1.2.0
packaging==23.1
pandas==2.1.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.0.0
platformdirs==3.10.0
polyglot==16.7.4
prompt-toolkit==3.0.39
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==13.0.0
pydantic==2.3.0
pydantic_core==2.6.3
Pygments==2.16.1
pymystem3==0.2.0
pyparsing==3.0.9
PySocks==1.7.1
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3
PyYAML==6.0.1
pyzmq==25.1.1
regex==2023.8.8
requests==2.31.0
responses==0.18.0
safetensors==0.3.3
scikit-learn==1.3.0
scipy==1.11.2
seaborn==0.12.2
selenium==4.12.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
stack-data==0.6.2
starlette==0.27.0
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.13.3
torch==2.0.1
torchvision==0.15.2
tornado==6.3.3
tqdm==4.66.1
traitlets==5.9.0
transformers==4.32.1
transformers-interpret==0.10.0
transliterate==1.10.2
trio==0.22.2
trio-websocket==0.10.3
typing_extensions==4.7.1
tzdata==2023.3
undetected-chromedriver==3.5.3
urllib3==2.0.4
uvicorn==0.23.2
uvloop==0.17.0
watchfiles==0.20.0
wcwidth==0.2.6
websockets==11.0.3
wsproto==1.2.0
xxhash==3.3.0
yarl==1.9.2
```
tfidf-train.ipynb — new file (+208 lines)
Cell 1:

```python
!pip3 install pymystem3 pandas sklearn
```

Output: pymystem3 and pandas already satisfied; the `sklearn` meta-package (a deprecated alias — scikit-learn itself is the real dependency) gets built and installed.
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"# dss.pickle holds the training corpus: a list of dicts with 'text' and 'label' keys\n",
"with open('dss.pickle', 'rb') as file:\n",
"    data = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import pandas as pd\n",
"from pymystem3 import Mystem\n",
"\n",
"def clean(text):\n",
"    # lowercase, then keep only Cyrillic letters and spaces\n",
"    text = re.sub('[^а-яё ]', ' ', str(text).lower())\n",
"    # collapse repeated spaces and trim the ends\n",
"    text = re.sub(r\" +\", \" \", text).strip()\n",
"    return text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"texts = list(map(lambda x: clean(x['text']), data))\n",
"labels = list(map(lambda x: x['label'], data))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"mstm = Mystem()\n",
"# Mystem's lemmatize() returns a token list ending with a newline token; [:-1] drops it\n",
"normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.DataFrame()\n",
"df['text'] = texts\n",
"df['norm'] = normalized\n",
"df['label'] = labels"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# 90% train; the held-out 10% is split into 8% validation and 2% test\n",
"train, test = train_test_split(df, test_size=0.1, random_state=42)\n",
"valid, test = train_test_split(test, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"model_tfidf = TfidfVectorizer(max_features=5000)\n",
"\n",
"# fit the vocabulary on the training split only; valid/test are merely transformed\n",
"train_tfidf = model_tfidf.fit_transform(train['norm'].values)\n",
"valid_tfidf = model_tfidf.transform(valid['norm'].values)\n",
"test_tfidf = model_tfidf.transform(test['norm'].values)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container 
{/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomForestClassifier</label><div class=\"sk-toggleable__content\"><pre>RandomForestClassifier(random_state=42)</pre></div></div></div></div></div>"
],
"text/plain": [
"RandomForestClassifier(random_state=42)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"cls = RandomForestClassifier(random_state=42)\n",
"cls.fit(train_tfidf, train['label'].values)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.629399514876379"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import f1_score\n",
"predictions = cls.predict(test_tfidf)\n",
"# f1_score expects (y_true, y_pred); 'weighted' averages per-class F1 by true-class support\n",
"f1_score(test['label'].values, predictions, average='weighted')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
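The notebook fits `model_tfidf` and `cls` but never persists them. A minimal sketch of how the pair could be pickled after training and reused for prediction (the bundle filename and helper names are illustrative assumptions, not part of the repo):

```python
# persist_tfidf_rf.py -- illustrative sketch; file and function names are assumptions
import pickle

def save_bundle(vectorizer, classifier, path="tfidf_rf.pickle"):
    # store the fitted vectorizer and classifier together so they stay in sync
    with open(path, "wb") as f:
        pickle.dump({"vectorizer": vectorizer, "classifier": classifier}, f)

def load_bundle(path="tfidf_rf.pickle"):
    with open(path, "rb") as f:
        bundle = pickle.load(f)
    return bundle["vectorizer"], bundle["classifier"]

def predict_label(raw_text, vectorizer, classifier, preprocess):
    # preprocess must repeat the training pipeline: clean() then Mystem lemmatization
    features = vectorizer.transform([preprocess(raw_text)])
    return classifier.predict(features)[0]
```

In the notebook this would amount to `save_bundle(model_tfidf, cls)` right after training; any inference path (for example the uvicorn app from the README) has to apply the same clean-and-lemmatize preprocessing, or the TF-IDF features will not match the trained vocabulary.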