ml/catboost-train.ipynb
2023-09-10 08:43:58 +03:00

162 lines
6.3 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: catboost in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (1.2.1)\n",
"Requirement already satisfied: graphviz in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (0.20.1)\n",
"Requirement already satisfied: matplotlib in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (3.7.2)\n",
"Requirement already satisfied: numpy>=1.16.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (1.25.2)\n",
"Requirement already satisfied: pandas>=0.24 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (2.1.0)\n",
"Requirement already satisfied: scipy in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (1.11.2)\n",
"Requirement already satisfied: plotly in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (5.16.1)\n",
"Requirement already satisfied: six in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from catboost) (1.16.0)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas>=0.24->catboost) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas>=0.24->catboost) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas>=0.24->catboost) (2023.3)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (0.11.0)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (4.42.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (23.1)\n",
"Requirement already satisfied: pillow>=6.2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (10.0.0)\n",
"Requirement already satisfied: pyparsing<3.1,>=2.3.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from matplotlib->catboost) (3.0.9)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from plotly->catboost) (8.2.3)\n"
]
}
],
"source": [
"# %pip installs into the environment of the running kernel, whereas a bare\n",
"# !pip3 may resolve to a different Python installation.\n",
"# The version is pinned to the release this notebook was last run with.\n",
"%pip install catboost==1.2.1"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"# Load the pre-built dataset. Later cells slice it and read the 'outs' and\n",
"# 'label' entries of each record.\n",
"# NOTE(review): unpickling can execute arbitrary code embedded in the file,\n",
"# so cb.pickle must come from a trusted source.\n",
"with open('cb.pickle', mode='rb') as dataset_file:\n",
"    cb_dataset = pickle.load(dataset_file)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from catboost import CatBoostClassifier"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Classifier configuration; the model is fitted further down on\n",
"# train / train_labels.\n",
"model = CatBoostClassifier(\n",
"    iterations=50_000,\n",
"    learning_rate=0.01,\n",
"    depth=8,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Label strings listed in class-id order: the id of a label is its position\n",
"# in this list.\n",
"# NOTE(review): the sequence is not alphabetical; presumably it has to match\n",
"# the integer labels that some records already carry, so confirm before\n",
"# reordering.\n",
"ordered_labels = [\n",
"    'AAA(RU)',   # 0\n",
"    'AA(RU)',    # 1\n",
"    'A+(RU)',    # 2\n",
"    'A(RU)',     # 3\n",
"    'A-(RU)',    # 4\n",
"    'BBB+(RU)',  # 5\n",
"    'BBB(RU)',   # 6\n",
"    'AA+(RU)',   # 7\n",
"    'BBB-(RU)',  # 8\n",
"    'AA-(RU)',   # 9\n",
"    'BB+(RU)',   # 10\n",
"    'BB-(RU)',   # 11\n",
"    'B+(RU)',    # 12\n",
"    'BB(RU)',    # 13\n",
"    'B(RU)',     # 14\n",
"    'B-(RU)',    # 15\n",
"    'C(RU)',     # 16\n",
"]\n",
"label2id = {label: class_id for class_id, label in enumerate(ordered_labels)}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Records before this index form the training set, the rest the test set.\n",
"SPLIT_INDEX = 6500\n",
"\n",
"\n",
"def encode_features(record):\n",
"    \"\"\"Build the six-value feature row for one dataset record.\n",
"\n",
"    Uses the first three entries of record['outs']: each 'answer' is mapped to\n",
"    its class id through label2id and each 'metric' is taken as-is.\n",
"    The column order (answer0, metric0, metric1, answer1, metric2, answer2)\n",
"    is deliberately left unchanged so train and test rows share one layout.\n",
"    \"\"\"\n",
"    outs = record['outs']\n",
"    first, second, third = outs[0], outs[1], outs[2]\n",
"    return [\n",
"        label2id[first['answer']], first['metric'],\n",
"        second['metric'], label2id[second['answer']],\n",
"        third['metric'], label2id[third['answer']],\n",
"    ]\n",
"\n",
"\n",
"def encode_label(label):\n",
"    \"\"\"Return the integer class id for a record label.\n",
"\n",
"    Integer labels are passed through untouched. Any other label gets the\n",
"    '(RU)' suffix appended and is then looked up in label2id, so an unknown\n",
"    label raises KeyError.\n",
"    \"\"\"\n",
"    if isinstance(label, int):\n",
"        return label\n",
"    return label2id[label + '(RU)']\n",
"\n",
"\n",
"train_records = cb_dataset[:SPLIT_INDEX]\n",
"train = [encode_features(record) for record in train_records]\n",
"train_labels = [encode_label(record['label']) for record in train_records]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Reuses SPLIT_INDEX and the encoders defined in the previous cell, so both\n",
"# sets are built by exactly the same code.\n",
"test_records = cb_dataset[SPLIT_INDEX:]\n",
"test = [encode_features(record) for record in test_records]\n",
"test_labels = [encode_label(record['label']) for record in test_records]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With 50000 iterations the default per-iteration log would flood the cell\n",
"# output, so progress is printed only every 1000th iteration. This changes\n",
"# logging only, not the training itself.\n",
"model.fit(train, train_labels, verbose=1000)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}