mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	evaluating on dev set during training
This commit is contained in:
		
							parent
							
								
									b6d788064a
								
							
						
					
					
						commit
						3b81b00954
					
				|  | @ -70,12 +70,10 @@ def is_dev(file_name): | ||||||
|     return file_name.endswith("3.txt") |     return file_name.endswith("3.txt") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def evaluate(predictions, golds): | def evaluate(predictions, golds, to_print=True): | ||||||
|     if len(predictions) != len(golds): |     if len(predictions) != len(golds): | ||||||
|         raise ValueError("predictions and gold entities should have the same length") |         raise ValueError("predictions and gold entities should have the same length") | ||||||
| 
 | 
 | ||||||
|     print("Evaluating", len(golds), "entities") |  | ||||||
| 
 |  | ||||||
|     tp = 0 |     tp = 0 | ||||||
|     fp = 0 |     fp = 0 | ||||||
|     fn = 0 |     fn = 0 | ||||||
|  | @ -89,17 +87,22 @@ def evaluate(predictions, golds): | ||||||
|         else: |         else: | ||||||
|             fp += 1 |             fp += 1 | ||||||
| 
 | 
 | ||||||
|  |     if to_print: | ||||||
|  |         print("Evaluating", len(golds), "entities") | ||||||
|         print("tp", tp) |         print("tp", tp) | ||||||
|         print("fp", fp) |         print("fp", fp) | ||||||
|         print("fn", fn) |         print("fn", fn) | ||||||
| 
 | 
 | ||||||
|     precision = tp / (tp + fp + 0.0000001) |     precision = 100 * tp / (tp + fp + 0.0000001) | ||||||
|     recall = tp / (tp + fn + 0.0000001) |     recall = 100 * tp / (tp + fn + 0.0000001) | ||||||
|     fscore = 2 * recall * precision / (recall + precision + 0.0000001) |     fscore = 2 * recall * precision / (recall + precision + 0.0000001) | ||||||
| 
 | 
 | ||||||
|     print("precision", round(100 * precision, 1), "%") |     if to_print: | ||||||
|     print("recall", round(100 * recall, 1), "%") |         print("precision", round(precision, 1), "%") | ||||||
|     print("Fscore", round(100 * fscore, 1), "%") |         print("recall", round(recall, 1), "%") | ||||||
|  |         print("Fscore", round(fscore, 1), "%") | ||||||
|  | 
 | ||||||
|  |     return precision, recall, fscore | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _prepare_pipeline(nlp, kb): | def _prepare_pipeline(nlp, kb): | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import os | ||||||
| import datetime | import datetime | ||||||
| from os import listdir | from os import listdir | ||||||
| import numpy as np | import numpy as np | ||||||
|  | from random import shuffle | ||||||
| 
 | 
 | ||||||
| from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator | from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator | ||||||
| 
 | 
 | ||||||
|  | @ -16,6 +17,8 @@ from thinc.t2v import Pooling, sum_pool, mean_pool | ||||||
| from thinc.t2t import ExtractWindow, ParametricAttention | from thinc.t2t import ExtractWindow, ParametricAttention | ||||||
| from thinc.misc import Residual, LayerNorm as LN | from thinc.misc import Residual, LayerNorm as LN | ||||||
| 
 | 
 | ||||||
|  | from spacy.tokens import Doc | ||||||
|  | 
 | ||||||
| """ TODO: this code needs to be implemented in pipes.pyx""" | """ TODO: this code needs to be implemented in pipes.pyx""" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -33,34 +36,93 @@ class EL_Model(): | ||||||
|         self.article_encoder = self._simple_encoder(in_width=300, out_width=96) |         self.article_encoder = self._simple_encoder(in_width=300, out_width=96) | ||||||
| 
 | 
 | ||||||
|     def train_model(self, training_dir, entity_descr_output, limit=None, to_print=True): |     def train_model(self, training_dir, entity_descr_output, limit=None, to_print=True): | ||||||
|         instances, pos_entities, neg_entities, doc_by_article = self._get_training_data(training_dir, |         Doc.set_extension("entity_id", default=None) | ||||||
|  | 
 | ||||||
|  |         train_instances, train_pos, train_neg, train_doc = self._get_training_data(training_dir, | ||||||
|                                                                                    entity_descr_output, |                                                                                    entity_descr_output, | ||||||
|  |                                                                                    False, | ||||||
|  |                                                                                    limit, to_print) | ||||||
|  | 
 | ||||||
|  |         dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir, | ||||||
|  |                                                                            entity_descr_output, | ||||||
|  |                                                                            True, | ||||||
|                                                                            limit, to_print) |                                                                            limit, to_print) | ||||||
| 
 | 
 | ||||||
|         if to_print: |         if to_print: | ||||||
|             print("Training on", len(instances), "instance clusters") |             print("Training on", len(train_instances), "instance clusters") | ||||||
|  |             print("Dev test on", len(dev_instances), "instance clusters") | ||||||
|             print() |             print() | ||||||
| 
 | 
 | ||||||
|         self.sgd_entity = self.begin_training(self.entity_encoder) |         self.sgd_entity = self.begin_training(self.entity_encoder) | ||||||
|         self.sgd_article = self.begin_training(self.article_encoder) |         self.sgd_article = self.begin_training(self.article_encoder) | ||||||
| 
 | 
 | ||||||
|  |         self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc) | ||||||
|  | 
 | ||||||
|         losses = {} |         losses = {} | ||||||
| 
 | 
 | ||||||
|         for inst_cluster in instances: |         for inst_cluster in train_instances: | ||||||
|             pos_ex = pos_entities.get(inst_cluster) |             pos_ex = train_pos.get(inst_cluster) | ||||||
|             neg_exs = neg_entities.get(inst_cluster, []) |             neg_exs = train_neg.get(inst_cluster, []) | ||||||
| 
 | 
 | ||||||
|             if pos_ex and neg_exs: |             if pos_ex and neg_exs: | ||||||
|                 article = inst_cluster.split(sep="_")[0] |                 article = inst_cluster.split(sep="_")[0] | ||||||
|                 entity_id = inst_cluster.split(sep="_")[1] |                 entity_id = inst_cluster.split(sep="_")[1] | ||||||
|                 article_doc = doc_by_article[article] |                 article_doc = train_doc[article] | ||||||
|                 self.update(article_doc, pos_ex, neg_exs, losses=losses) |                 self.update(article_doc, pos_ex, neg_exs, losses=losses) | ||||||
|  |                 p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc) | ||||||
|  |                 print(round(fscore, 1)) | ||||||
|             # TODO |             # TODO | ||||||
|             # elif not pos_ex: |             # elif not pos_ex: | ||||||
|                 # print("Weird. Couldn't find pos example for",  inst_cluster) |                 # print("Weird. Couldn't find pos example for",  inst_cluster) | ||||||
|             # elif not neg_exs: |             # elif not neg_exs: | ||||||
|                 # print("Weird. Couldn't find neg examples for",  inst_cluster) |                 # print("Weird. Couldn't find neg examples for",  inst_cluster) | ||||||
| 
 | 
 | ||||||
|  |     def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc): | ||||||
|  |         predictions = list() | ||||||
|  |         golds = list() | ||||||
|  | 
 | ||||||
|  |         for inst_cluster in dev_instances: | ||||||
|  |             pos_ex = dev_pos.get(inst_cluster) | ||||||
|  |             neg_exs = dev_neg.get(inst_cluster, []) | ||||||
|  |             ex_to_id = dict() | ||||||
|  | 
 | ||||||
|  |             if pos_ex and neg_exs: | ||||||
|  |                 ex_to_id[pos_ex] = pos_ex._.entity_id | ||||||
|  |                 for neg_ex in neg_exs: | ||||||
|  |                     ex_to_id[neg_ex] = neg_ex._.entity_id | ||||||
|  | 
 | ||||||
|  |                 article = inst_cluster.split(sep="_")[0] | ||||||
|  |                 entity_id = inst_cluster.split(sep="_")[1] | ||||||
|  |                 article_doc = dev_doc[article] | ||||||
|  | 
 | ||||||
|  |                 examples = list(neg_exs) | ||||||
|  |                 examples.append(pos_ex) | ||||||
|  |                 shuffle(examples) | ||||||
|  | 
 | ||||||
|  |                 best_entity, lowest_mse = self._predict(examples, article_doc) | ||||||
|  |                 predictions.append(ex_to_id[best_entity]) | ||||||
|  |                 golds.append(ex_to_id[pos_ex]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         # TODO: use lowest_mse and combine with prior probability | ||||||
|  |         p, r, F = run_el.evaluate(predictions, golds, to_print=False) | ||||||
|  |         return p, r, F | ||||||
|  | 
 | ||||||
|  |     def _predict(self, entities, article_doc): | ||||||
|  |         doc_encoding = self.article_encoder([article_doc]) | ||||||
|  | 
 | ||||||
|  |         lowest_mse = None | ||||||
|  |         best_entity = None | ||||||
|  | 
 | ||||||
|  |         for entity in entities: | ||||||
|  |             entity_encoding = self.entity_encoder([entity]) | ||||||
|  |             mse, _ = self._calculate_similarity(doc_encoding, entity_encoding) | ||||||
|  |             if not best_entity or mse < lowest_mse: | ||||||
|  |                 lowest_mse = mse | ||||||
|  |                 best_entity = entity | ||||||
|  | 
 | ||||||
|  |         return best_entity, lowest_mse | ||||||
|  | 
 | ||||||
|     def _simple_encoder(self, in_width, out_width): |     def _simple_encoder(self, in_width, out_width): | ||||||
|         conv_depth = 1 |         conv_depth = 1 | ||||||
|         cnn_maxout_pieces = 3 |         cnn_maxout_pieces = 3 | ||||||
|  | @ -145,7 +207,7 @@ class EL_Model(): | ||||||
|         # print("true index", true_index) |         # print("true index", true_index) | ||||||
|         # print("true prob", entity_probs[true_index]) |         # print("true prob", entity_probs[true_index]) | ||||||
| 
 | 
 | ||||||
|         print(true_mse) |         # print("training loss", true_mse) | ||||||
| 
 | 
 | ||||||
|         # print() |         # print() | ||||||
| 
 | 
 | ||||||
|  | @ -198,13 +260,14 @@ class EL_Model(): | ||||||
|     def _get_labels(self): |     def _get_labels(self): | ||||||
|         return tuple(self.labels) |         return tuple(self.labels) | ||||||
| 
 | 
 | ||||||
|     def _get_training_data(self, training_dir, entity_descr_output, limit, to_print): |     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print): | ||||||
|         id_to_descr = kb_creator._get_id_to_description(entity_descr_output) |         id_to_descr = kb_creator._get_id_to_description(entity_descr_output) | ||||||
| 
 | 
 | ||||||
|         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir, |         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir, | ||||||
|                                                                                          collect_correct=True, |                                                                                          collect_correct=True, | ||||||
|                                                                                          collect_incorrect=True) |                                                                                          collect_incorrect=True) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|         instances = list() |         instances = list() | ||||||
|         local_vectors = list()   # TODO: local vectors |         local_vectors = list()   # TODO: local vectors | ||||||
|         doc_by_article = dict() |         doc_by_article = dict() | ||||||
|  | @ -214,7 +277,7 @@ class EL_Model(): | ||||||
|         cnt = 0 |         cnt = 0 | ||||||
|         for f in listdir(training_dir): |         for f in listdir(training_dir): | ||||||
|             if not limit or cnt < limit: |             if not limit or cnt < limit: | ||||||
|                 if not run_el.is_dev(f): |                 if dev == run_el.is_dev(f): | ||||||
|                     article_id = f.replace(".txt", "") |                     article_id = f.replace(".txt", "") | ||||||
|                     if cnt % 500 == 0 and to_print: |                     if cnt % 500 == 0 and to_print: | ||||||
|                         print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset") |                         print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset") | ||||||
|  | @ -230,6 +293,7 @@ class EL_Model(): | ||||||
|                         if descr: |                         if descr: | ||||||
|                             instances.append(article_id + "_" + mention) |                             instances.append(article_id + "_" + mention) | ||||||
|                             doc_descr = self.nlp(descr) |                             doc_descr = self.nlp(descr) | ||||||
|  |                             doc_descr._.entity_id = entity_pos | ||||||
|                             pos_entities[article_id + "_" + mention] = doc_descr |                             pos_entities[article_id + "_" + mention] = doc_descr | ||||||
| 
 | 
 | ||||||
|                     for mention, entity_negs in incorrect_entries[article_id].items(): |                     for mention, entity_negs in incorrect_entries[article_id].items(): | ||||||
|  | @ -237,6 +301,7 @@ class EL_Model(): | ||||||
|                             descr = id_to_descr.get(entity_neg) |                             descr = id_to_descr.get(entity_neg) | ||||||
|                             if descr: |                             if descr: | ||||||
|                                 doc_descr = self.nlp(descr) |                                 doc_descr = self.nlp(descr) | ||||||
|  |                                 doc_descr._.entity_id = entity_neg | ||||||
|                                 descr_list = neg_entities.get(article_id + "_" + mention, []) |                                 descr_list = neg_entities.get(article_id + "_" + mention, []) | ||||||
|                                 descr_list.append(doc_descr) |                                 descr_list.append(doc_descr) | ||||||
|                                 neg_entities[article_id + "_" + mention] = descr_list |                                 neg_entities[article_id + "_" + mention] = descr_list | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user