ml/bert_train.ipynb

288 lines
16 KiB
Plaintext
Raw Permalink Normal View History

2023-09-10 08:43:58 +03:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Requirement already satisfied: transformers in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (4.32.1)\n",
"Requirement already satisfied: datasets in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.14.4)\n",
"Requirement already satisfied: pandas in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.1.0)\n",
"Requirement already satisfied: evaluate in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (0.4.0)\n",
"Requirement already satisfied: numpy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.25.2)\n",
"Requirement already satisfied: filelock in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (3.12.3)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.16.4)\n",
"Requirement already satisfied: packaging>=20.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (2023.8.8)\n",
"Requirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (0.3.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (13.0.0)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: xxhash in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from datasets) (3.8.5)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: responses<0.19 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (3.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.7.1)\n",
"Requirement already satisfied: six>=1.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->transformers) (2023.7.22)\n"
]
}
],
"source": [
"# %pip installs into the running kernel's environment (unlike !pip3); versions pinned to those this notebook was last run with.\n",
"%pip install transformers==4.32.1 datasets==2.14.4 pandas==2.1.0 evaluate==0.4.0 numpy==1.25.2"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Rating label <-> integer class id.\n",
"# NOTE: ids are NOT in rating order (e.g. 'AA+(RU)' is 7). They presumably must match the\n",
"# integer labels stored in dss.pickle, so do not renumber without re-encoding the data.\n",
"label2id = {\n",
"    'AAA(RU)': 0,\n",
"    'AA(RU)': 1,\n",
"    'A+(RU)': 2,\n",
"    'A(RU)': 3,\n",
"    'A-(RU)': 4,\n",
"    'BBB+(RU)': 5,\n",
"    'BBB(RU)': 6,\n",
"    'AA+(RU)': 7,\n",
"    'BBB-(RU)': 8,\n",
"    'AA-(RU)': 9,\n",
"    'BB+(RU)': 10,\n",
"    'BB-(RU)': 11,\n",
"    'B+(RU)': 12,\n",
"    'BB(RU)': 13,\n",
"    'B(RU)': 14,\n",
"    'B-(RU)': 15,\n",
"    'C(RU)': 16,\n",
"}\n",
"# Derive the inverse map rather than maintaining a second hand-written copy that can drift.\n",
"id2label = {idx: label for label, idx in label2id.items()}\n",
"assert sorted(id2label) == list(range(len(label2id))), 'label ids must be unique and contiguous from 0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# NOTE(review): pickle.load can execute arbitrary code - only load dss.pickle from a trusted source.\n",
"with open('dss.pickle', 'rb') as file:\n",
"    data = pickle.load(file)\n",
"\n",
"# Fail fast if the records are missing or carry labels the label map cannot decode.\n",
"assert len(data) > 0, 'dss.pickle is empty'\n",
"assert all(0 <= record['label'] < len(label2id) for record in data), 'label outside label2id range'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'text': 'Кредитный рейтинг АКБ Энергобанк ПАО далее Энергобанк Банк обусловлен удовлетворительным бизнеспрофилем в сочетании с сильной достаточностью капитала критическим рискпрофилем и удовлетворительной оценкой ликвидности и фондирования учитывающей концентрацию обязательств Банка на средствах крупнейших кредиторовОсновная деятельность Энергобанка сконцентрирована в Республике Татарстан далее РТ где он занимает устойчивые рыночные позиции занял шестое место по размеру активов на На российском банковском рынке Банк имеет невысокую долю на он занимал е место по величине собственного капитала и е место по размеру активовКлючевыми направлениями деятельности Банка являются кредитование предприятий агропромышленного комплекса строительства и торговли а также залоговое розничное кредитование Контролирующими собственниками Банка являются ИН Хайруллин и АН Хайруллин владеющие около акций через АО Эдельвейс КорпорейшнОценка бизнеспрофиля Банка отражает его относительно невысокую долю на российском рынке банковских услуг и выраженную региональную направленность его деятельности несмотря на планы по открытию отделений за пределами РТОперационный доход Банка характеризуется низким хотя и возрастающим уровнем диверсификации по итогам года значение индекса ХерфиндаляХиршмана составило и свидетельствовало о повышенной концентрации на кредитовании корпоративного сектора около операционного дохода Качество управления Банком оценивается АКРА как удовлетворительное и соответствующее среднему уровню в российском банковском секторе в целом Организационная структура Энергобанка соответствует масштабам и особенностям его бизнеса Структура собственности Банка прозрачна при этом отмечается связанность его операций с аффилированными с собственниками Банка компаниямиСтратегия Банка на период до конца года предполагает планомерное наращивание размера активов за',\n",
" 'label': 11}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample_record = data[0]  # one raw record: report 'text' plus integer 'label'\n",
"sample_record"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"# Pretrained encoder plus a classification head sized to the label map\n",
"# (the head weights are newly initialized, as the warning below reports).\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
"    \"cointegrated/rubert-tiny\",\n",
"    num_labels=len(id2label),\n",
"    id2label=id2label,\n",
"    label2id=label2id,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"cointegrated/rubert-tiny\")\n",
"\n",
"def token(text):\n",
"    \"\"\"Tokenize a batch of examples; intended for `Dataset.map(batched=True)`.\n",
"\n",
"    `text` is a batch dict whose 'text' entry holds the raw strings. Sequences are\n",
"    truncated to max_length=512 but deliberately NOT padded here: the Trainer's\n",
"    DataCollatorWithPadding pads each training batch on the fly, so padding at map\n",
"    time would only store useless pad tokens. No `return_tensors` either, because\n",
"    `datasets` stores the result as Arrow lists (and unpadded rows cannot be stacked).\n",
"    \"\"\"\n",
"    return tokenizer(text['text'], truncation=True, max_length=512)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 11500/11500 [00:04<00:00, 2842.10 examples/s]\n",
"Map: 100%|██████████| 1845/1845 [00:00<00:00, 3125.02 examples/s]\n"
]
}
],
"source": [
"from datasets import Dataset\n",
"import pandas as pd\n",
"import random\n",
"\n",
"SEED = 42\n",
"TRAIN_SIZE = 11500  # records after this index form the evaluation split\n",
"\n",
"# Shuffle a COPY with a seeded RNG: the split is reproducible, and `data` is left\n",
"# untouched, so re-running this cell cannot move examples between train and test.\n",
"shuffled = list(data)\n",
"random.Random(SEED).shuffle(shuffled)\n",
"train_records = shuffled[:TRAIN_SIZE]\n",
"test_records = shuffled[TRAIN_SIZE:]\n",
"assert test_records, 'TRAIN_SIZE leaves no records for the evaluation split'\n",
"\n",
"train = Dataset.from_pandas(pd.DataFrame(data=train_records))\n",
"test = Dataset.from_pandas(pd.DataFrame(data=test_records))\n",
"assert len(train) + len(test) == len(data)\n",
"\n",
"tokenized_train = train.map(token, batched=True)\n",
"tokenized_test = test.map(token, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import evaluate\n",
"import numpy as np\n",
"\n",
"f1 = evaluate.load(\"f1\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
"    \"\"\"Score evaluation predictions with macro-averaged F1.\n",
"\n",
"    `eval_pred` unpacks into (logits, labels); the predicted class is the\n",
"    arg-max over axis 1 of the logits.\n",
"    \"\"\"\n",
"    logits, references = eval_pred\n",
"    predicted_ids = np.argmax(logits, axis=1)\n",
"    return f1.compute(predictions=predicted_ids, references=references, average='macro')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found safetensors installation, but --save_safetensors=False. Safetensors should be a preferred weights saving format due to security and performance reasons. If your model cannot be saved by safetensors please feel free to open an issue at https://github.com/huggingface/safetensors!\n",
"PyTorch: setting up devices\n",
"The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
]
}
],
"source": [
"from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
"\n",
"# Pads each training/eval batch dynamically (the tokenizer step does not pad).\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"akra_model\",\n",
"    learning_rate=2e-5,\n",
"    per_device_train_batch_size=16,\n",
"    per_device_eval_batch_size=16,\n",
"    num_train_epochs=10,\n",
"    weight_decay=0.01,\n",
"    evaluation_strategy=\"epoch\",\n",
"    save_strategy=\"epoch\",\n",
"    load_best_model_at_end=True,\n",
"    # Without this the 'best' checkpoint is picked by eval loss, ignoring the\n",
"    # macro-F1 that compute_metrics reports every epoch.\n",
"    metric_for_best_model=\"f1\",\n",
"    # Recommended by the Trainer warning: save weights as safetensors, not pickle.\n",
"    save_safetensors=True,\n",
")\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=tokenized_train,\n",
"    eval_dataset=tokenized_test,\n",
"    tokenizer=tokenizer,\n",
"    data_collator=data_collator,\n",
"    compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_result = trainer.train()  # fine-tune; with load_best_model_at_end the best checkpoint is restored\n",
"train_result"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}