ml/tfidf-train.ipynb

209 lines
11 KiB
Plaintext
Raw Permalink Normal View History

2023-09-10 08:43:58 +03:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pymystem3 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (0.2.0)\n",
"Requirement already satisfied: pandas in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (2.1.0)\n",
"Collecting sklearn\n",
" Downloading sklearn-0.0.post9.tar.gz (3.6 kB)\n",
" Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25hRequirement already satisfied: requests in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pymystem3) (2.31.0)\n",
"Requirement already satisfied: numpy>=1.23.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (1.25.2)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: tzdata>=2022.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from pandas) (2023.3)\n",
"Requirement already satisfied: six>=1.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (3.2.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages (from requests->pymystem3) (2023.7.22)\n",
"Building wheels for collected packages: sklearn\n",
" Building wheel for sklearn (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for sklearn: filename=sklearn-0.0.post9-py3-none-any.whl size=2952 sha256=de085da5188e0680130af47d37bf6a7803a4dbec121af8adf834ac3d03747231\n",
" Stored in directory: /Users/ilya/Library/Caches/pip/wheels/ef/63/d1/f1671e1e93b7ef4d35df483f9b2485e6dd21941da9a92296fb\n",
"Successfully built sklearn\n",
"Installing collected packages: sklearn\n",
"Successfully installed sklearn-0.0.post9\n"
]
}
],
"source": [
"!pip3 install pymystem3 pandas sklearn"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"with open('dss.pickle', 'rb') as file:\n",
" data = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import pandas as pd\n",
"from pymystem3 import Mystem\n",
"\n",
"def clean(text):\n",
" text = re.sub('[^а-яё ]', ' ', str(text).lower())\n",
" text = re.sub(r\" +\", \" \", text).strip()\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"texts = list(map(lambda x: clean(x['text']), data))\n",
"labels = list(map(lambda x: x['label'], data))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"mstm = Mystem()\n",
"normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.DataFrame()\n",
"df['text'] = texts\n",
"df['norm'] = normalized\n",
"df['label'] = labels"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"train, test = train_test_split(df, test_size=0.1, random_state=42)\n",
"valid, test = train_test_split(test, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"model_tfidf = TfidfVectorizer(max_features=5000)\n",
"\n",
"train_tfidf = model_tfidf.fit_transform(train['norm'].values)\n",
"valid_tfidf = model_tfidf.transform(valid['norm'].values)\n",
"test_tfidf = model_tfidf.transform(test['norm'].values)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-r
],
"text/plain": [
"RandomForestClassifier(random_state=42)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"cls = RandomForestClassifier(random_state=42)\n",
"cls.fit(train_tfidf, train['label'].values)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.629399514876379"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import f1_score\n",
"predictions = cls.predict(test_tfidf)\n",
"f1_score(predictions, test['label'].values, average='weighted')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}