{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CatBoost classifier on `cb.pickle`\n",
    "\n",
    "Fits a `CatBoostClassifier` on six features per record: the `answer` (encoded with `label2id`) and the `metric` of the first three entries of each record's `outs` list. The first `TRAIN_SIZE` records form the training split and the remaining records form the test split.\n",
    "\n",
    "**TODO:** `test` / `test_labels` are built below but are not yet used for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel.\n",
    "# The version is pinned to the one this notebook was last run with.\n",
    "%pip install -q catboost==1.2.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "from catboost import CatBoostClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH = 'cb.pickle'  # pickled, sliceable collection of records\n",
    "TRAIN_SIZE = 6500        # records [0, TRAIN_SIZE) train the model; the rest are held out\n",
    "LABEL_SUFFIX = '(RU)'    # appended to string labels that lack it before the label2id lookup\n",
    "\n",
    "ITERATIONS = 50000\n",
    "LEARNING_RATE = 1e-2\n",
    "DEPTH = 8\n",
    "SEED = 0                 # CatBoost documents 0 as its default seed; pinned here to make it explicit\n",
    "LOG_EVERY = 1000         # print training progress every LOG_EVERY iterations instead of every one"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(security): pickle.load can execute arbitrary code. Only open files from a trusted source.\n",
    "with open(DATA_PATH, 'rb') as file:\n",
    "    cb_dataset = pickle.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label encoding and feature extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The ids are deliberately left in their original, non-sorted order (e.g. 'AA+(RU)' is 7).\n",
    "# Do not renumber: records may carry integer labels, which presumably were produced\n",
    "# with this exact mapping -- TODO confirm against whatever wrote cb.pickle.\n",
    "label2id = {\n",
    "    'AAA(RU)': 0,\n",
    "    'AA(RU)': 1,\n",
    "    'A+(RU)': 2,\n",
    "    'A(RU)': 3,\n",
    "    'A-(RU)': 4,\n",
    "    'BBB+(RU)': 5,\n",
    "    'BBB(RU)': 6,\n",
    "    'AA+(RU)': 7,\n",
    "    'BBB-(RU)': 8,\n",
    "    'AA-(RU)': 9,\n",
    "    'BB+(RU)': 10,\n",
    "    'BB-(RU)': 11,\n",
    "    'B+(RU)': 12,\n",
    "    'BB(RU)': 13,\n",
    "    'B(RU)': 14,\n",
    "    'B-(RU)': 15,\n",
    "    'C(RU)': 16\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_label(label):\n",
    "    \"\"\"Return the integer class id for a record's label.\n",
    "\n",
    "    A string is looked up in ``label2id``, appending ``LABEL_SUFFIX`` first\n",
    "    when the string is not already a key. Any other value is treated as an\n",
    "    already-encoded id and returned unchanged.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: the string is unknown both with and without the suffix.\n",
    "    \"\"\"\n",
    "    if isinstance(label, str):\n",
    "        key = label if label in label2id else label + LABEL_SUFFIX\n",
    "        return label2id[key]\n",
    "    return label\n",
    "\n",
    "\n",
    "def build_features(record):\n",
    "    \"\"\"Return the six-element feature row for one record.\n",
    "\n",
    "    Reads ``answer`` and ``metric`` from ``record['outs'][0..2]``. The column\n",
    "    order is kept exactly as in the original notebook so the fitted model is\n",
    "    unchanged: answer0, metric0, metric1, answer1, metric2, answer2.\n",
    "    \"\"\"\n",
    "    outs = record['outs']\n",
    "    first, second, third = outs[0], outs[1], outs[2]\n",
    "    return [\n",
    "        label2id[first['answer']], first['metric'],\n",
    "        second['metric'], label2id[second['answer']],\n",
    "        third['metric'], label2id[third['answer']],\n",
    "    ]\n",
    "\n",
    "\n",
    "def build_split(records):\n",
    "    \"\"\"Return ``(features, labels)`` lists built from an iterable of records.\"\"\"\n",
    "    features, labels = [], []\n",
    "    for record in records:\n",
    "        features.append(build_features(record))\n",
    "        labels.append(encode_label(record['label']))\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, train_labels = build_split(cb_dataset[:TRAIN_SIZE])\n",
    "test, test_labels = build_split(cb_dataset[TRAIN_SIZE:])\n",
    "\n",
    "assert train, 'training split is empty'\n",
    "assert len(train) == len(train_labels)\n",
    "assert len(test) == len(test_labels)\n",
    "assert len(train) + len(test) == len(cb_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CatBoostClassifier(iterations=ITERATIONS,\n",
    "                           learning_rate=LEARNING_RATE,\n",
    "                           depth=DEPTH,\n",
    "                           random_seed=SEED,\n",
    "                           verbose=LOG_EVERY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fit(train, train_labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}