mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	upsampling and batch processing
This commit is contained in:
		
							parent
							
								
									1a16490d20
								
							
						
					
					
						commit
						97241a3ed7
					
				|  | @ -78,8 +78,15 @@ def evaluate(predictions, golds, to_print=True): | |||
|     fp = 0 | ||||
|     fn = 0 | ||||
| 
 | ||||
|     corrects = 0 | ||||
|     incorrects = 0 | ||||
| 
 | ||||
|     for pred, gold in zip(predictions, golds): | ||||
|         is_correct = pred == gold | ||||
|         if is_correct: | ||||
|             corrects += 1 | ||||
|         else: | ||||
|             incorrects += 1 | ||||
|         if not pred: | ||||
|             if not is_correct:  # we don't care about tn | ||||
|                 fn += 1 | ||||
|  | @ -98,12 +105,15 @@ def evaluate(predictions, golds, to_print=True): | |||
|     recall = 100 * tp / (tp + fn + 0.0000001) | ||||
|     fscore = 2 * recall * precision / (recall + precision + 0.0000001) | ||||
| 
 | ||||
|     accuracy = corrects / (corrects + incorrects) | ||||
| 
 | ||||
|     if to_print: | ||||
|         print("precision", round(precision, 1), "%") | ||||
|         print("recall", round(recall, 1), "%") | ||||
|         print("Fscore", round(fscore, 1), "%") | ||||
|         print("Accuracy", round(accuracy, 1), "%") | ||||
| 
 | ||||
|     return precision, recall, fscore | ||||
|     return precision, recall, fscore, accuracy | ||||
| 
 | ||||
| 
 | ||||
| def _prepare_pipeline(nlp, kb): | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import datetime | |||
| from os import listdir | ||||
| import numpy as np | ||||
| import random | ||||
| from random import shuffle | ||||
| from thinc.neural._classes.convolution import ExtractWindow | ||||
| 
 | ||||
| from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator | ||||
|  | @ -26,17 +27,17 @@ from spacy.tokens import Doc | |||
| 
 | ||||
| class EL_Model: | ||||
| 
 | ||||
|     PRINT_LOSS = False | ||||
|     PRINT_F = True | ||||
|     PRINT_TRAIN = False | ||||
|     EPS = 0.0000000005 | ||||
|     CUTOFF = 0.5 | ||||
| 
 | ||||
|     BATCH_SIZE = 5 | ||||
| 
 | ||||
|     INPUT_DIM = 300 | ||||
|     HIDDEN_1_WIDTH = 256  # 10 | ||||
|     HIDDEN_1_WIDTH = 32   # 10 | ||||
|     HIDDEN_2_WIDTH = 32  # 6 | ||||
|     ENTITY_WIDTH = 64     # 4 | ||||
|     ARTICLE_WIDTH = 128   # 8 | ||||
|     DESC_WIDTH = 64     # 4 | ||||
|     ARTICLE_WIDTH = 64   # 8 | ||||
| 
 | ||||
|     DROP = 0.1 | ||||
| 
 | ||||
|  | @ -48,7 +49,7 @@ class EL_Model: | |||
|         self.kb = kb | ||||
| 
 | ||||
|         self._build_cnn(in_width=self.INPUT_DIM, | ||||
|                         entity_width=self.ENTITY_WIDTH, | ||||
|                         desc_width=self.DESC_WIDTH, | ||||
|                         article_width=self.ARTICLE_WIDTH, | ||||
|                         hidden_1_width=self.HIDDEN_1_WIDTH, | ||||
|                         hidden_2_width=self.HIDDEN_2_WIDTH) | ||||
|  | @ -57,121 +58,118 @@ class EL_Model: | |||
|         # raise errors instead of runtime warnings in case of int/float overflow | ||||
|         np.seterr(all='raise') | ||||
| 
 | ||||
|         train_inst, train_pos, train_neg, train_texts = self._get_training_data(training_dir, | ||||
|                                                                                 entity_descr_output, | ||||
|                                                                                 False, | ||||
|                                                                                 trainlimit, | ||||
|                                                                                 balance=True, | ||||
|                                                                                 to_print=False) | ||||
|         train_ent, train_gold, train_desc, train_article, train_texts = self._get_training_data(training_dir, | ||||
|                                                                                                 entity_descr_output, | ||||
|                                                                                                 False, | ||||
|                                                                                                 trainlimit, | ||||
|                                                                                                 to_print=False) | ||||
| 
 | ||||
|         train_pos_entities = [k for k,v in train_gold.items() if v] | ||||
|         train_neg_entities = [k for k,v in train_gold.items() if not v] | ||||
| 
 | ||||
|         train_pos_count = len(train_pos_entities) | ||||
|         train_neg_count = len(train_neg_entities) | ||||
| 
 | ||||
|         # upsample positives to 50-50 distribution | ||||
|         while train_pos_count < train_neg_count: | ||||
|             train_ent.append(random.choice(train_pos_entities)) | ||||
|             train_pos_count += 1 | ||||
| 
 | ||||
|         # upsample negatives to 50-50 distribution | ||||
|         while train_neg_count < train_pos_count: | ||||
|             train_ent.append(random.choice(train_neg_entities)) | ||||
|             train_neg_count += 1 | ||||
| 
 | ||||
|         shuffle(train_ent) | ||||
| 
 | ||||
|         dev_ent, dev_gold, dev_desc, dev_article, dev_texts = self._get_training_data(training_dir, | ||||
|                                                                                       entity_descr_output, | ||||
|                                                                                       True, | ||||
|                                                                                       devlimit, | ||||
|                                                                                       to_print=False) | ||||
|         shuffle(dev_ent) | ||||
| 
 | ||||
|         dev_pos_count = len([g for g in dev_gold.values() if g]) | ||||
|         dev_neg_count = len([g for g in dev_gold.values() if not g]) | ||||
| 
 | ||||
|         dev_inst, dev_pos, dev_neg, dev_texts = self._get_training_data(training_dir, | ||||
|                                                                         entity_descr_output, | ||||
|                                                                         True, | ||||
|                                                                         devlimit, | ||||
|                                                                         balance=False, | ||||
|                                                                         to_print=False) | ||||
|         self._begin_training() | ||||
| 
 | ||||
|         print() | ||||
|         self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_random", calc_random=True) | ||||
|         self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_pre", avg=False) | ||||
| 
 | ||||
|         instance_pos_count = 0 | ||||
|         instance_neg_count = 0 | ||||
|         self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_random", calc_random=True) | ||||
|         print() | ||||
|         self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_pre", avg=True) | ||||
| 
 | ||||
|         if to_print: | ||||
|             print() | ||||
|             print("Training on", len(train_inst.values()), "articles") | ||||
|             print("Dev test on", len(dev_inst.values()), "articles") | ||||
|             print("Training on", len(train_ent), "entities in", len(train_texts), "articles") | ||||
|             print("Training instances pos/neg", train_pos_count, train_neg_count) | ||||
|             print() | ||||
|             print("Dev test on", len(dev_ent), "entities in", len(dev_texts), "articles") | ||||
|             print("Dev instances pos/neg", dev_pos_count, dev_neg_count) | ||||
|             print() | ||||
|             print(" CUTOFF", self.CUTOFF) | ||||
|             print(" INPUT_DIM", self.INPUT_DIM) | ||||
|             print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH) | ||||
|             print(" ENTITY_WIDTH", self.ENTITY_WIDTH) | ||||
|             print(" DESC_WIDTH", self.DESC_WIDTH) | ||||
|             print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH) | ||||
|             print(" HIDDEN_2_WIDTH", self.HIDDEN_2_WIDTH) | ||||
|             print(" DROP", self.DROP) | ||||
|             print() | ||||
| 
 | ||||
|         # TODO: proper batches. Currently 1 article at the time | ||||
|         # TODO shuffle data (currently positive is always followed by several negatives) | ||||
|         article_count = 0 | ||||
|         for article_id, inst_cluster_set in train_inst.items(): | ||||
|             try: | ||||
|                 # if to_print: | ||||
|                     # print() | ||||
|                     # print(article_count, "Training on article", article_id) | ||||
|                 article_count += 1 | ||||
|                 article_text = train_texts[article_id] | ||||
|                 entities = list() | ||||
|                 golds = list() | ||||
|                 for inst_cluster in inst_cluster_set: | ||||
|                     entities.append(train_pos.get(inst_cluster)) | ||||
|                     golds.append(float(1.0)) | ||||
|                     instance_pos_count += 1 | ||||
|                     for neg_entity in train_neg.get(inst_cluster, []): | ||||
|                         entities.append(neg_entity) | ||||
|                         golds.append(float(0.0)) | ||||
|                         instance_neg_count += 1 | ||||
|         start = 0 | ||||
|         stop = min(self.BATCH_SIZE, len(train_ent)) | ||||
|         processed = 0 | ||||
| 
 | ||||
|                 self.update(article_text=article_text, entities=entities, golds=golds) | ||||
|         while start < len(train_ent): | ||||
|             next_batch = train_ent[start:stop] | ||||
| 
 | ||||
|                 # dev eval | ||||
|                 self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_inter_avg", avg=True) | ||||
|             except ValueError as e: | ||||
|                 print("Error in article id", article_id) | ||||
|             golds = [train_gold[e] for e in next_batch] | ||||
|             descs = [train_desc[e] for e in next_batch] | ||||
|             articles = [train_texts[train_article[e]] for e in next_batch] | ||||
| 
 | ||||
|             self.update(entities=next_batch, golds=golds, descs=descs, texts=articles) | ||||
|             self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_inter", avg=True) | ||||
| 
 | ||||
|             processed += len(next_batch) | ||||
| 
 | ||||
|             start = start + self.BATCH_SIZE | ||||
|             stop = min(stop + self.BATCH_SIZE, len(train_ent)) | ||||
| 
 | ||||
|         if to_print: | ||||
|             print() | ||||
|             print("Trained on", instance_pos_count, "/", instance_neg_count, "instances pos/neg") | ||||
|             print("Trained on", processed, "entities in total") | ||||
| 
 | ||||
|     def _test_dev(self, instances, pos, neg, texts_by_id, print_string, avg=False, calc_random=False): | ||||
|         predictions = list() | ||||
|         golds = list() | ||||
|     def _test_dev(self, entities, gold_by_entity, desc_by_entity, article_by_entity, texts_by_id, print_string, avg=True, calc_random=False): | ||||
|         golds = [gold_by_entity[e] for e in entities] | ||||
| 
 | ||||
|         for article_id, inst_cluster_set in instances.items(): | ||||
|             for inst_cluster in inst_cluster_set: | ||||
|                 pos_ex = pos.get(inst_cluster) | ||||
|                 neg_exs = neg.get(inst_cluster, []) | ||||
|         if calc_random: | ||||
|             predictions = self._predict_random(entities=entities) | ||||
| 
 | ||||
|                 article = inst_cluster.split(sep="_")[0] | ||||
|                 entity_id = inst_cluster.split(sep="_")[1] | ||||
|                 article_doc = self.nlp(texts_by_id[article]) | ||||
|                 entities = [self.nlp(pos_ex)] | ||||
|                 golds.append(float(1.0)) | ||||
|                 for neg_ex in neg_exs: | ||||
|                     entities.append(self.nlp(neg_ex)) | ||||
|                     golds.append(float(0.0)) | ||||
| 
 | ||||
|                 if calc_random: | ||||
|                     preds = self._predict_random(entities=entities) | ||||
|                 else: | ||||
|                     preds = self._predict(article_doc=article_doc, entities=entities, avg=avg) | ||||
|                 predictions.extend(preds) | ||||
|         else: | ||||
|             desc_docs = self.nlp.pipe([desc_by_entity[e] for e in entities]) | ||||
|             article_docs = self.nlp.pipe([texts_by_id[article_by_entity[e]] for e in entities]) | ||||
|             predictions = self._predict(entities=entities, article_docs=article_docs, desc_docs=desc_docs, avg=avg) | ||||
| 
 | ||||
|         # TODO: combine with prior probability | ||||
|         p, r, f = run_el.evaluate(predictions, golds, to_print=False) | ||||
|         if self.PRINT_F: | ||||
|             print("p/r/F", print_string, round(p, 1), round(r, 1), round(f, 1)) | ||||
| 
 | ||||
|         p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False) | ||||
|         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds)) | ||||
|         if self.PRINT_LOSS: | ||||
|             print("loss", print_string, round(loss, 5)) | ||||
| 
 | ||||
|         print("p/r/F/acc/loss", print_string, round(p, 1), round(r, 1), round(f, 1), round(acc, 2), round(loss, 5)) | ||||
| 
 | ||||
|         return loss, p, r, f | ||||
| 
 | ||||
|     def _predict(self, article_doc, entities, avg=False, apply_threshold=True): | ||||
|     def _predict(self, entities, article_docs, desc_docs, avg=True, apply_threshold=True): | ||||
|         if avg: | ||||
|             with self.article_encoder.use_params(self.sgd_article.averages) \ | ||||
|                  and self.entity_encoder.use_params(self.sgd_entity.averages): | ||||
|                 doc_encoding = self.article_encoder([article_doc])[0] | ||||
|                 entity_encodings = self.entity_encoder(entities) | ||||
|                  and self.desc_encoder.use_params(self.sgd_entity.averages): | ||||
|                 doc_encodings = self.article_encoder(article_docs) | ||||
|                 desc_encodings = self.desc_encoder(desc_docs) | ||||
| 
 | ||||
|         else: | ||||
|             doc_encoding = self.article_encoder([article_doc])[0] | ||||
|             entity_encodings = self.entity_encoder(entities) | ||||
|             doc_encodings = self.article_encoder(article_docs) | ||||
|             desc_encodings = self.desc_encoder(desc_docs) | ||||
| 
 | ||||
|         concat_encodings = [list(entity_encodings[i]) + list(doc_encoding) for i in range(len(entities))] | ||||
|         concat_encodings = [list(desc_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))] | ||||
|         np_array_list = np.asarray(concat_encodings) | ||||
| 
 | ||||
|         if avg: | ||||
|  | @ -189,16 +187,16 @@ class EL_Model: | |||
| 
 | ||||
|     def _predict_random(self, entities, apply_threshold=True): | ||||
|         if not apply_threshold: | ||||
|             return [float(random.uniform(0,1)) for e in entities] | ||||
|             return [float(random.uniform(0, 1)) for e in entities] | ||||
|         else: | ||||
|             return [float(1.0) if random.uniform(0,1) > self.CUTOFF else float(0.0) for e in entities] | ||||
|             return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for e in entities] | ||||
| 
 | ||||
|     def _build_cnn(self, in_width, entity_width, article_width, hidden_1_width, hidden_2_width): | ||||
|     def _build_cnn(self, in_width, desc_width, article_width, hidden_1_width, hidden_2_width): | ||||
|         with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): | ||||
|             self.entity_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=entity_width) | ||||
|             self.desc_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=desc_width) | ||||
|             self.article_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=article_width) | ||||
| 
 | ||||
|             in_width = entity_width + article_width | ||||
|             in_width = desc_width + article_width | ||||
|             out_width = hidden_2_width | ||||
| 
 | ||||
|             self.model = Affine(out_width, in_width) \ | ||||
|  | @ -229,80 +227,78 @@ class EL_Model: | |||
| 
 | ||||
|     def _begin_training(self): | ||||
|         self.sgd_article = create_default_optimizer(self.article_encoder.ops) | ||||
|         self.sgd_entity = create_default_optimizer(self.entity_encoder.ops) | ||||
|         self.sgd_entity = create_default_optimizer(self.desc_encoder.ops) | ||||
|         self.sgd = create_default_optimizer(self.model.ops) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def get_loss(predictions, golds): | ||||
|         d_scores = (predictions - golds) | ||||
|         gradient = d_scores.mean() | ||||
|         loss = (d_scores ** 2).mean() | ||||
|         return loss, d_scores | ||||
|         return loss, gradient | ||||
| 
 | ||||
|     # TODO: multiple docs/articles | ||||
|     def update(self, article_text, entities, golds, apply_threshold=True): | ||||
|         article_doc = self.nlp(article_text) | ||||
|         # entity_docs = list(self.nlp.pipe(entities)) | ||||
|     def update(self, entities, golds, descs, texts): | ||||
|         golds = self.model.ops.asarray(golds) | ||||
| 
 | ||||
|         for entity, gold in zip(entities, golds): | ||||
|             doc_encodings, bp_doc = self.article_encoder.begin_update([article_doc], drop=self.DROP) | ||||
|             doc_encoding = doc_encodings[0] | ||||
|         desc_docs = self.nlp.pipe(descs) | ||||
|         article_docs = self.nlp.pipe(texts) | ||||
| 
 | ||||
|             entity_doc = self.nlp(entity) | ||||
|             # print("entity_docs", type(entity_doc)) | ||||
|         doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP) | ||||
| 
 | ||||
|             entity_encodings, bp_entity = self.entity_encoder.begin_update([entity_doc], drop=self.DROP) | ||||
|             entity_encoding = entity_encodings[0] | ||||
|             # print("entity_encoding", len(entity_encoding), entity_encoding) | ||||
|         desc_encodings, bp_entity = self.desc_encoder.begin_update(desc_docs, drop=self.DROP) | ||||
| 
 | ||||
|             concat_encodings = [list(entity_encoding) + list(doc_encoding)]  #  for i in range(len(entities)) | ||||
|             # print("concat_encodings", len(concat_encodings), concat_encodings) | ||||
|         concat_encodings = [list(desc_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))] | ||||
| 
 | ||||
|             prediction, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP) | ||||
|             # predictions = self.model.ops.flatten(predictions) | ||||
|         predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP) | ||||
|         predictions = self.model.ops.flatten(predictions) | ||||
| 
 | ||||
|             # print("prediction", prediction) | ||||
|             # golds = self.model.ops.asarray(golds) | ||||
|             # print("gold", gold) | ||||
|         # print("entities", entities) | ||||
|         # print("predictions", predictions) | ||||
|         # print("golds", golds) | ||||
| 
 | ||||
|             loss, gradient = self.get_loss(prediction, gold) | ||||
|         loss, gradient = self.get_loss(predictions, golds) | ||||
| 
 | ||||
|             if self.PRINT_LOSS and self.PRINT_TRAIN: | ||||
|                 print("loss train", round(loss, 5)) | ||||
|         if self.PRINT_TRAIN: | ||||
|             print("loss train", round(loss, 5)) | ||||
| 
 | ||||
|             gradient = float(gradient) | ||||
|             # print("gradient", gradient) | ||||
|             # print("loss", loss) | ||||
|         gradient = float(gradient) | ||||
|         # print("gradient", gradient) | ||||
|         # print("loss", loss) | ||||
| 
 | ||||
|             model_gradient = bp_model(gradient, sgd=self.sgd) | ||||
|             # print("model_gradient", model_gradient) | ||||
|         model_gradient = bp_model(gradient, sgd=self.sgd) | ||||
|         # print("model_gradient", model_gradient) | ||||
| 
 | ||||
|             # concat = entity + doc, but doc is the same within this function (TODO: multiple docs/articles) | ||||
|             doc_gradient = model_gradient[0][self.ENTITY_WIDTH:] | ||||
|             entity_gradients = list() | ||||
|             for x in model_gradient: | ||||
|                 entity_gradients.append(list(x[0:self.ENTITY_WIDTH])) | ||||
|         # concat = desc + doc, but doc is the same within this function (TODO: multiple docs/articles) | ||||
|         doc_gradient = model_gradient[0][self.DESC_WIDTH:] | ||||
|         entity_gradients = list() | ||||
|         for x in model_gradient: | ||||
|             entity_gradients.append(list(x[0:self.DESC_WIDTH])) | ||||
| 
 | ||||
|             # print("doc_gradient", doc_gradient) | ||||
|             # print("entity_gradients", entity_gradients) | ||||
|         # print("doc_gradient", doc_gradient) | ||||
|         # print("entity_gradients", entity_gradients) | ||||
| 
 | ||||
|             bp_doc([doc_gradient], sgd=self.sgd_article) | ||||
|             bp_entity(entity_gradients, sgd=self.sgd_entity) | ||||
|         bp_doc([doc_gradient], sgd=self.sgd_article) | ||||
|         bp_entity(entity_gradients, sgd=self.sgd_entity) | ||||
| 
 | ||||
|     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print): | ||||
|     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print): | ||||
|         id_to_descr = kb_creator._get_id_to_description(entity_descr_output) | ||||
| 
 | ||||
|         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir, | ||||
|                                                                                          collect_correct=True, | ||||
|                                                                                          collect_incorrect=True) | ||||
| 
 | ||||
|         instance_by_article = dict() | ||||
|         local_vectors = list()   # TODO: local vectors | ||||
|         text_by_article = dict() | ||||
|         pos_entities = dict() | ||||
|         neg_entities = dict() | ||||
|         gold_by_entity = dict() | ||||
|         desc_by_entity = dict() | ||||
|         article_by_entity = dict() | ||||
|         entities = list() | ||||
| 
 | ||||
|         cnt = 0 | ||||
|         for f in listdir(training_dir): | ||||
|         next_entity_nr = 0 | ||||
|         files = listdir(training_dir) | ||||
|         shuffle(files) | ||||
|         for f in files: | ||||
|             if not limit or cnt < limit: | ||||
|                 if dev == run_el.is_dev(f): | ||||
|                     article_id = f.replace(".txt", "") | ||||
|  | @ -313,29 +309,29 @@ class EL_Model: | |||
|                         with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file: | ||||
|                             text = file.read() | ||||
|                             text_by_article[article_id] = text | ||||
|                             instance_by_article[article_id] = set() | ||||
| 
 | ||||
|                     for mention, entity_pos in correct_entries[article_id].items(): | ||||
|                         descr = id_to_descr.get(entity_pos) | ||||
|                         if descr: | ||||
|                             instance_by_article[article_id].add(article_id + "_" + mention) | ||||
|                             pos_entities[article_id + "_" + mention] = descr | ||||
|                             entities.append(next_entity_nr) | ||||
|                             gold_by_entity[next_entity_nr] = 1 | ||||
|                             desc_by_entity[next_entity_nr] = descr | ||||
|                             article_by_entity[next_entity_nr] = article_id | ||||
|                             next_entity_nr += 1 | ||||
| 
 | ||||
|                     for mention, entity_negs in incorrect_entries[article_id].items(): | ||||
|                         if not balance or pos_entities.get(article_id + "_" + mention): | ||||
|                             neg_count = 0 | ||||
|                             for entity_neg in entity_negs: | ||||
|                                 # if balance, keep only 1 negative instance for each positive instance | ||||
|                                 if neg_count < 1 or not balance: | ||||
|                                     descr = id_to_descr.get(entity_neg) | ||||
|                                     if descr: | ||||
|                                         descr_list = neg_entities.get(article_id + "_" + mention, []) | ||||
|                                         descr_list.append(descr) | ||||
|                                         neg_entities[article_id + "_" + mention] = descr_list | ||||
|                                         neg_count += 1 | ||||
|                         for entity_neg in entity_negs: | ||||
|                             descr = id_to_descr.get(entity_neg) | ||||
|                             if descr: | ||||
|                                 entities.append(next_entity_nr) | ||||
|                                 gold_by_entity[next_entity_nr] = 0 | ||||
|                                 desc_by_entity[next_entity_nr] = descr | ||||
|                                 article_by_entity[next_entity_nr] = article_id | ||||
|                                 next_entity_nr += 1 | ||||
| 
 | ||||
|         if to_print: | ||||
|             print() | ||||
|             print("Processed", cnt, "training articles, dev=" + str(dev)) | ||||
|             print() | ||||
|         return instance_by_article, pos_entities, neg_entities, text_by_article | ||||
|         return entities, gold_by_entity, desc_by_entity, article_by_entity, text_by_article | ||||
| 
 | ||||
|  |  | |||
|  | @ -111,7 +111,7 @@ if __name__ == "__main__": | |||
|         print("STEP 6: training", datetime.datetime.now()) | ||||
|         my_nlp = spacy.load('en_core_web_md') | ||||
|         trainer = EL_Model(kb=my_kb, nlp=my_nlp) | ||||
|         trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20) | ||||
|         trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=400, devlimit=50) | ||||
|         print() | ||||
| 
 | ||||
|     # STEP 7: apply the EL algorithm on the dev dataset | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user