ml/bert_train.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Requirement already satisfied: transformers in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (4.32.1)\n",
"Requirement already satisfied: datasets in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.14.4)\n",
"Requirement already satisfied: pandas in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.1.0)\n",
"Requirement already satisfied: evaluate in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (0.4.0)\n",
"Requirement already satisfied: numpy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.25.2)\n",
"Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (3.12.3)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.16.4)\n",
"Requirement already satisfied: packaging>=20.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (2023.8.8)\n",
"Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.3.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (13.0.0)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: xxhash in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (3.8.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: responses<0.19 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (3.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (2023.7.22)\n"
]
}
],
"source": [
"!pip3 install transformers datasets pandas evaluate numpy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"label2id = {\n",
" 'AAA(RU)': 0,\n",
" 'AA(RU)': 1, \n",
" 'A+(RU)': 2,\n",
" 'A(RU)': 3,\n",
" 'A-(RU)': 4,\n",
" 'BBB+(RU)': 5,\n",
" 'BBB(RU)': 6, \n",
" 'AA+(RU)': 7,\n",
" 'BBB-(RU)': 8,\n",
" 'AA-(RU)': 9,\n",
" 'BB+(RU)': 10, \n",
" 'BB-(RU)': 11, \n",
" 'B+(RU)': 12,\n",
" 'BB(RU)': 13, \n",
" 'B(RU)': 14,\n",
" 'B-(RU)': 15, \n",
" 'C(RU)': 16\n",
"}\n",
"id2label = {0: 'AAA(RU)',\n",
" 1: 'AA(RU)',\n",
" 2: 'A+(RU)',\n",
" 3: 'A(RU)',\n",
" 4: 'A-(RU)',\n",
" 5: 'BBB+(RU)',\n",
" 6: 'BBB(RU)',\n",
" 7: 'AA+(RU)',\n",
" 8: 'BBB-(RU)',\n",
" 9: 'AA-(RU)',\n",
" 10: 'BB+(RU)',\n",
" 11: 'BB-(RU)',\n",
" 12: 'B+(RU)',\n",
" 13: 'BB(RU)',\n",
" 14: 'B(RU)',\n",
" 15: 'B-(RU)',\n",
" 16: 'C(RU)'}"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open('dss.pickle', 'rb') as file:\n",
" data = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'text': 'Кредитный рейтинг АКБ Энергобанк ПАО далее Энергобанк Банк обусловлен удовлетворительным бизнеспрофилем в сочетании с сильной достаточностью капитала критическим рискпрофилем и удовлетворительной оценкой ликвидности и фондирования учитывающей концентрацию обязательств Банка на средствах крупнейших кредиторовОсновная деятельность Энергобанка сконцентрирована в Республике Татарстан далее РТ где он занимает устойчивые рыночные позиции занял шестое место по размеру активов на На российском банковском рынке Банк имеет невысокую долю на он занимал е место по величине собственного капитала и е место по размеру активовКлючевыми направлениями деятельности Банка являются кредитование предприятий агропромышленного комплекса строительства и торговли а также залоговое розничное кредитование Контролирующими собственниками Банка являются ИН Хайруллин и АН Хайруллин владеющие около акций через АО Эдельвейс КорпорейшнОценка бизнеспрофиля Банка отражает его относительно невысокую долю на российском рынке банковских услуг и выраженную региональную направленность его деятельности несмотря на планы по открытию отделений за пределами РТОперационный доход Банка характеризуется низким хотя и возрастающим уровнем диверсификации по итогам года значение индекса ХерфиндаляХиршмана составило и свидетельствовало о повышенной концентрации на кредитовании корпоративного сектора около операционного дохода Качество управления Банком оценивается АКРА как удовлетворительное и соответствующее среднему уровню в российском банковском секторе в целом Организационная структура Энергобанка соответствует масштабам и особенностям его бизнеса Структура собственности Банка прозрачна при этом отмечается связанность его операций с аффилированными с собственниками Банка компаниямиСтратегия Банка на период до конца года предполагает планомерное наращивание размера активов за',\n",
" 'label': 11}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[0]"
]
},
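{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before splitting the data it helps to look at the class balance: the 17 rating classes are unlikely to be equally represented. A minimal check, assuming `data` is the list of `{'text', 'label'}` dicts shown above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Count how many reports fall into each rating class.\n",
"counts = Counter(row['label'] for row in data)\n",
"for label_id, n in sorted(counts.items()):\n",
"    print(id2label[label_id], n)"
]
},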
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
" \"cointegrated/rubert-tiny\", num_labels=len(id2label.keys()), id2label=id2label, label2id=label2id\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"cointegrated/rubert-tiny\")\n",
"def token(text):\n",
" return tokenizer(text['text'], padding=True, truncation=True, max_length=512, return_tensors='pt')"
]
},
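{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check: tokenize a single report and confirm the 512-token truncation; `data[0]` is just an arbitrary example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The rating reports are long, so most of them should hit the 512-token cap.\n",
"sample = tokenizer(data[0]['text'], truncation=True, max_length=512)\n",
"print(len(sample['input_ids']))"
]
},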
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 11500/11500 [00:04<00:00, 2842.10 examples/s]\n",
"Map: 100%|██████████| 1845/1845 [00:00<00:00, 3125.02 examples/s]\n"
]
}
],
"source": [
"from datasets import Dataset\n",
"import pandas as pd\n",
"from random import shuffle\n",
"shuffle(data)\n",
"train = data[:11500]\n",
"test = data[11500:]\n",
"train = Dataset.from_pandas(pd.DataFrame(data=train))\n",
"test = Dataset.from_pandas(pd.DataFrame(data=test))\n",
"tokenized_train = train.map(token, batched=True)\n",
"tokenized_test = test.map(token, batched=True)"
]
},
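{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split above is a plain random shuffle. With 17 imbalanced classes, a stratified split keeps the per-class proportions similar in train and test. The sketch below is only an alternative, not used downstream; it assumes scikit-learn is available (it is not installed by the pip cell at the top) and that every class has at least two examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative to the shuffle-based split: stratify on the rating label.\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"train_rows, test_rows = train_test_split(\n",
"    data,\n",
"    test_size=1845,  # same test size as above\n",
"    stratify=[row['label'] for row in data],\n",
"    random_state=42,\n",
")"
]
},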
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import evaluate\n",
"import numpy as np\n",
"\n",
"f1 = evaluate.load(\"f1\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" predictions = np.argmax(predictions, axis=1)\n",
" return f1.compute(predictions=predictions, references=labels, average='macro')"
]
},
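{
"cell_type": "markdown",
"metadata": {},
"source": [
"A cheap way to confirm the metric wiring before a long training run is to call `compute_metrics` on random dummy logits (purely illustrative numbers):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dummy (batch, num_labels) logits plus matching integer labels.\n",
"dummy_logits = np.random.rand(8, len(id2label))\n",
"dummy_labels = np.random.randint(0, len(id2label), size=8)\n",
"print(compute_metrics((dummy_logits, dummy_labels)))"
]
},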
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found safetensors installation, but --save_safetensors=False. Safetensors should be a preferred weights saving format due to security and performance reasons. If your model cannot be saved by safetensors please feel free to open an issue at https://github.com/huggingface/safetensors!\n",
"PyTorch: setting up devices\n",
"The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
]
}
],
"source": [
"from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
"\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
"training_args = TrainingArguments(\n",
" output_dir=\"akra_model\",\n",
" learning_rate=2e-5,\n",
" per_device_train_batch_size=16,\n",
" per_device_eval_batch_size=16,\n",
" num_train_epochs=10,\n",
" weight_decay=0.01,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
")\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_train,\n",
" eval_dataset=tokenized_test,\n",
" tokenizer=tokenizer,\n",
" data_collator=data_collator,\n",
" compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
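{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once training has finished, the fine-tuned model can be smoke-tested on a single report. Here a training example is reused purely to check the wiring, not to measure quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Run one report through the fine-tuned model and map the top logit back to a rating.\n",
"inputs = tokenizer(data[0]['text'], truncation=True, max_length=512, return_tensors='pt').to(model.device)\n",
"with torch.no_grad():\n",
"    logits = model(**inputs).logits\n",
"print(id2label[logits.argmax(dim=-1).item()])"
]
}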
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}