{ "cells": [ { "cell_type": "code", "execution_count": 6, "id": "a371550d-cbd7-4435-a41b-3fb35cccaa74", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Archive: test_dataset/Тестовый датасет (2).zip\n", " inflating: test_dataset/2.docx \n", " inflating: test_dataset/3.docx \n", " inflating: test_dataset/4.docx \n", " inflating: test_dataset/5.docx \n", " inflating: test_dataset/6.docx \n", " inflating: test_dataset/7.docx \n", " inflating: test_dataset/8.docx \n", " inflating: test_dataset/9.docx \n", " inflating: test_dataset/10.docx \n", " inflating: test_dataset/Название_команды.csv \n", " inflating: test_dataset/Пояснения к валидации.docx \n", " inflating: test_dataset/1.docx \n" ] } ], "source": [ "!unzip test_dataset/Тестовый\\ датасет\\ \\(2\\).zip -d test_dataset/" ] }, { "cell_type": "code", "execution_count": 1, "id": "969b75f5-a211-44d2-b704-c3c4c90f4401", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/venv/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch\n", "import numpy as np\n", "from transformers import AutoTokenizer, AutoModel\n", "from tqdm import tqdm\n", "tqdm.pandas()\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "sns.set()\n", "import numpy as np\n", "\n", "import pandas as pd\n", "import docx\n", "import os\n", "import re\n", "\n", "device = 'cuda'" ] }, { "cell_type": "code", "execution_count": 3, "id": "358aad5c-210b-450f-a273-8b759ed88b9c", "metadata": {}, "outputs": [], "source": [ "#word preprocessing functions\n", "\n", "import nltk\n", "from nltk.tokenize import word_tokenize\n", "\n", "from nltk.corpus import stopwords\n", "from nltk.stem import SnowballStemmer\n", "from nltk.tokenize import word_tokenize\n", "\n", "russian_stopwords = stopwords.words(\"russian\")\n", "\n", "russian_stopwords.append('российской')\n", "russian_stopwords.append('федерации')\n", "russian_stopwords.append('федерального')\n", "russian_stopwords.append('настоящих')\n", "russian_stopwords.append('соответствии')\n", "russian_stopwords.append('также')\n", "russian_stopwords.append('рф')\n", "russian_stopwords.append('ред')\n", "\n", "russian_stopwords.append('субсидии')\n", "russian_stopwords.append('предоставления')\n", "\n", "\n", "def lowercase(text):\n", " return str(text).lower()\n", "\n", "def clean_symb(text):\n", " return re.sub(r'[^\\w]', ' ', text)\n", "\n", "def clear_token(text):\n", " return word_tokenize(text)\n", "\n", "def clean_stopwords(token):\n", " return ' '.join([i for i in token.split(' ') if i not in russian_stopwords])\n", "\n", "def clean_stem(token):\n", " return [st.stem(i) for i in token]\n", "\n", "### combo 3 prev\n", "def make_clean(s):\n", " return ' '.join(clean_stem(clean_stopwords(clear_token(s))))" ] }, { "cell_type": "code", "execution_count": 4, "id": "4d884c5c-cb0f-44ef-8127-175497c7da47", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "11" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#find all docs\n", "\n", "data_path = './test_dataset/'\n", "documents = os.listdir(data_path)\n", "documents = [data_path+d for d in documents]\n", "len(documents)" ] }, { "cell_type": "code", "execution_count": 5, "id": "e009af9e-97bf-4e9f-928b-c6bd50424b4b", "metadata": 
{}, "outputs": [], "source": [ "#load and parse doc functions\n", "\n", "from nltk.corpus import stopwords\n", "from nltk.stem import SnowballStemmer\n", "from nltk.tokenize import word_tokenize\n", "\n", "\n", "def get_text(filename):\n", " doc = docx.Document(filename)\n", " fullText = ''\n", " for para in doc.paragraphs:\n", " for run in para.runs:\n", " fullText+=run.text\n", " fullText+='\\n'\n", " return fullText\n", "\n", "def split_text(text):\n", " texts, groups = [],[]\n", " regt = re.findall(r\"{(.*?)}(.*?){(.*?)}\",text.replace('\\n',''))\n", " for t in regt:\n", " if t[0]==t[-1]:\n", " texts.append(t[1])\n", " groups.append(int(t[0]))\n", " else:\n", " print(t)\n", " \n", " return texts, groups" ] }, { "cell_type": "code", "execution_count": 6, "id": "c1ee24fe-988f-4d9e-97e6-522bf2e41798", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('66', '18. ', '67')\n", "('67', '', '68')\n", "('68', '', '69')\n", "('69', '', '70')\n", "('Q = (К1 x ГС1) + (К2 x ГС1),где:{42', 'К1 - кандидаты на назначение государственных стипендий, являющиеся молодыми (до 35 лет включительно) творческими деятелями в области культуры и искусства;', '42')\n" ] } ], "source": [ "#load data\n", "\n", "all_text, all_groups, doc_paths, doc_names = [],[],[],[]\n", "for d in documents:\n", " if 'ipynb' not in d:\n", " text = get_text(d)\n", " texts,groups = split_text(text)\n", " all_text.extend(texts)\n", " all_groups.extend(groups)\n", " doc_paths.extend([d for a in range(len(texts))])\n", " doc_names.extend([d.split('/')[-1] for a in range(len(texts))])" ] }, { "cell_type": "code", "execution_count": 41, "id": "f3d481a5-b3e4-4367-bfa9-c41c43c05a7b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(846, 7)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#load documents\n", "#apply preprocessing to documents\n", "\n", "df = pd.DataFrame([doc_paths, doc_names,all_text, all_groups]).T\n", "df.columns = ['path','doc','text','id']\n", "df['r_text']='r'\n", "df['r_text'] = df.text.apply(lowercase)\n", "df['r_text'] = df.r_text.apply(clean_symb)\n", "df['r_text'] = df.r_text.apply(lambda x:''.join([a for a in x if not a.isdigit()]))\n", "df['r_text'] = df.r_text.apply(lambda x:' '.join([a for a in x.split(' ') if len(a)>1]))\n", "df['r_text'] = df.r_text.apply(clean_stopwords)\n", "df['text_size'] = df['text'].apply(lambda x: len(x.strip()))\n", "df['is_text'] = df.r_text.apply(lambda x:x.strip().isdigit())\n", "df = df[~df.is_text]\n", "df = df[df['text_size']>5]\n", "df = df.reset_index(drop=True)\n", "df.shape" ] }, { "cell_type": "code", "execution_count": 42, "id": "2e3997a7-7e66-4077-8827-007b64f663f5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "правил 175\n", "организации 155\n", "предоставлении 123\n", "соглашения 115\n", "конкурса 90\n", "министерством 85\n", "средств 82\n", "финансового 79\n", "числе 76\n", "заявок 74\n", "dtype: int64" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#show most popular words in datset\n", "\n", "all_words = np.hstack([np.array(a) for a in df.r_text.apply(lambda x:[a for a in x.split(' ') if len(a)>0]).values])\n", "pd.Series(all_words).value_counts()[:10]" ] }, { "cell_type": "code", "execution_count": 43, "id": "a3b50eb7-e8b4-4503-92dd-9f0c50b7c2e9", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from transformers import BertTokenizer, BertForSequenceClassification\n", "from 
{ "cell_type": "code", "execution_count": 43, "id": "a3b50eb7-e8b4-4503-92dd-9f0c50b7c2e9", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import BertTokenizer, BertForSequenceClassification\n", "from torch.utils.data import Dataset, DataLoader\n", "from transformers import AdamW, get_linear_schedule_with_warmup\n", "\n", "from bert_dataset import CustomDataset" ] },
{ "cell_type": "code", "execution_count": 11, "id": "b227db93-5193-4851-95dc-1b89277e3c2d", "metadata": {}, "outputs": [], "source": [ "#load model\n", "model = torch.load('bert_opossum_best_0.6753.pt')\n", "model.to(device)\n", "model.eval()  #inference only: disable dropout\n", "pass  #suppress the model repr output\n", "\n", "#load tokenizer\n", "tokenizer = BertTokenizer.from_pretrained('sberbank-ai/ruBert-base')" ] },
{ "cell_type": "code", "execution_count": 44, "id": "e1a0c200-ec37-4058-96d1-b625f839e922", "metadata": {}, "outputs": [], "source": [ "#predict function: tokenize one text and return the argmax class id\n", "def predict(text):\n", "    encoding = tokenizer.encode_plus(\n", "        text,\n", "        add_special_tokens=True,\n", "        max_length=512,\n", "        return_token_type_ids=False,\n", "        truncation=True,\n", "        padding='max_length',\n", "        return_attention_mask=True,\n", "        return_tensors='pt',\n", "    )\n", "\n", "    #encode_plus with return_tensors='pt' already returns a batch of size 1\n", "    input_ids = encoding['input_ids'].to(device)\n", "    attention_mask = encoding['attention_mask'].to(device)\n", "\n", "    with torch.no_grad():  #no gradients needed at inference time\n", "        outputs = model(\n", "            input_ids=input_ids,\n", "            attention_mask=attention_mask\n", "        )\n", "\n", "    prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]\n", "    return prediction" ] },
{ "cell_type": "code", "execution_count": 49, "id": "b4095596-35ac-44e0-b6bb-dfc10bd712db", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████| 846/846 [00:29<00:00, 28.49it/s]\n" ] } ], "source": [ "#apply the predict function to every fragment\n", "df['class'] = df.r_text.progress_apply(predict)" ] },
{ "cell_type": "code", "execution_count": 50, "id": "0392b3c9-7e93-47cf-8729-98be9dbb9285", "metadata": {}, "outputs": [], "source": [ "#extract the numeric file id from the document name, e.g. '1.docx' -> 1\n", "df['file_id'] = df.doc.apply(lambda x: int(x.split('.')[0]))" ] },
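{ "cell_type": "markdown", "id": "ad0c0003-0000-4000-8000-000000000005", "metadata": {}, "source": [ "A minimal post-prediction check: how the predicted class ids are distributed. A sketch that assumes `df['class']` was filled by the cell above; nothing here feeds into the final output." ] },
{ "cell_type": "code", "execution_count": null, "id": "ad0c0003-0000-4000-8000-000000000006", "metadata": {}, "outputs": [], "source": [ "#frequency of each predicted class id, most common first\n", "df['class'].value_counts().head(10)" ] },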
\n", " | file_id | \n", "id | \n", "class | \n", "
---|---|---|---|
0 | \n", "1 | \n", "1 | \n", "2 | \n", "
1 | \n", "1 | \n", "2 | \n", "2 | \n", "
2 | \n", "1 | \n", "3 | \n", "2 | \n", "
3 | \n", "1 | \n", "4 | \n", "3 | \n", "
4 | \n", "1 | \n", "5 | \n", "7 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
847 | \n", "10 | \n", "67 | \n", "38 | \n", "
848 | \n", "10 | \n", "68 | \n", "38 | \n", "
849 | \n", "10 | \n", "69 | \n", "38 | \n", "
850 | \n", "10 | \n", "70 | \n", "38 | \n", "
851 | \n", "10 | \n", "71 | \n", "7 | \n", "
852 rows × 3 columns
\n", "