mirror of https://github.com/explosion/spaCy.git
	upsampling and batch processing
This commit is contained in:
parent 1a16490d20
commit 97241a3ed7
@@ -78,8 +78,15 @@ def evaluate(predictions, golds, to_print=True):
     fp = 0
     fn = 0
 
+    corrects = 0
+    incorrects = 0
+
     for pred, gold in zip(predictions, golds):
         is_correct = pred == gold
+        if is_correct:
+            corrects += 1
+        else:
+            incorrects += 1
         if not pred:
             if not is_correct:  # we don't care about tn
                 fn += 1
@@ -98,12 +105,15 @@ def evaluate(predictions, golds, to_print=True):
     recall = 100 * tp / (tp + fn + 0.0000001)
     fscore = 2 * recall * precision / (recall + precision + 0.0000001)
 
+    accuracy = corrects / (corrects + incorrects)
+
     if to_print:
         print("precision", round(precision, 1), "%")
         print("recall", round(recall, 1), "%")
         print("Fscore", round(fscore, 1), "%")
+        print("Accuracy", round(accuracy, 1), "%")
 
-    return precision, recall, fscore
+    return precision, recall, fscore, accuracy
 
 
 def _prepare_pipeline(nlp, kb):
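Note: with this change, evaluate reports plain accuracy alongside precision/recall/F-score. A minimal standalone sketch of the same bookkeeping (a hypothetical condensed helper, not part of the commit), assuming binary 0.0/1.0 predictions and golds as in the example code:

def sketch_metrics(predictions, golds, eps=0.0000001):
    # tp/fp/fn as in the patched evaluate; tn is only counted via `corrects`
    tp = sum(1 for p, g in zip(predictions, golds) if p and p == g)
    fp = sum(1 for p, g in zip(predictions, golds) if p and p != g)
    fn = sum(1 for p, g in zip(predictions, golds) if not p and p != g)
    corrects = sum(1 for p, g in zip(predictions, golds) if p == g)

    precision = 100 * tp / (tp + fp + eps)
    recall = 100 * tp / (tp + fn + eps)
    fscore = 2 * recall * precision / (recall + precision + eps)
    accuracy = corrects / len(golds)  # equals corrects / (corrects + incorrects)
    return precision, recall, fscore, accuracy

Since accuracy here stays in the 0-1 range while the other three metrics are scaled to percentages, the "%" printed next to it in the diff should be read accordingly.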
@@ -6,6 +6,7 @@ import datetime
 from os import listdir
 import numpy as np
 import random
+from random import shuffle
 from thinc.neural._classes.convolution import ExtractWindow
 
 from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
@@ -26,17 +27,17 @@ from spacy.tokens import Doc
 
 class EL_Model:
 
-    PRINT_LOSS = False
-    PRINT_F = True
     PRINT_TRAIN = False
     EPS = 0.0000000005
     CUTOFF = 0.5
 
+    BATCH_SIZE = 5
+
     INPUT_DIM = 300
-    HIDDEN_1_WIDTH = 256  # 10
+    HIDDEN_1_WIDTH = 32   # 10
     HIDDEN_2_WIDTH = 32  # 6
-    ENTITY_WIDTH = 64     # 4
-    ARTICLE_WIDTH = 128   # 8
+    DESC_WIDTH = 64     # 4
+    ARTICLE_WIDTH = 64   # 8
 
     DROP = 0.1
 
@@ -48,7 +49,7 @@ class EL_Model:
         self.kb = kb
 
         self._build_cnn(in_width=self.INPUT_DIM,
-                        entity_width=self.ENTITY_WIDTH,
+                        desc_width=self.DESC_WIDTH,
                         article_width=self.ARTICLE_WIDTH,
                         hidden_1_width=self.HIDDEN_1_WIDTH,
                         hidden_2_width=self.HIDDEN_2_WIDTH)
@@ -57,121 +58,118 @@ class EL_Model:
         # raise errors instead of runtime warnings in case of int/float overflow
         np.seterr(all='raise')
 
-        train_inst, train_pos, train_neg, train_texts = self._get_training_data(training_dir,
-                                                                                entity_descr_output,
-                                                                                False,
-                                                                                trainlimit,
-                                                                                balance=True,
-                                                                                to_print=False)
+        train_ent, train_gold, train_desc, train_article, train_texts = self._get_training_data(training_dir,
+                                                                                                entity_descr_output,
+                                                                                                False,
+                                                                                                trainlimit,
+                                                                                                to_print=False)
+
+        train_pos_entities = [k for k,v in train_gold.items() if v]
+        train_neg_entities = [k for k,v in train_gold.items() if not v]
+
+        train_pos_count = len(train_pos_entities)
+        train_neg_count = len(train_neg_entities)
+
+        # upsample positives to 50-50 distribution
+        while train_pos_count < train_neg_count:
+            train_ent.append(random.choice(train_pos_entities))
+            train_pos_count += 1
+
+        # upsample negatives to 50-50 distribution
+        while train_neg_count < train_pos_count:
+            train_ent.append(random.choice(train_neg_entities))
+            train_neg_count += 1
+
+        shuffle(train_ent)
+
+        dev_ent, dev_gold, dev_desc, dev_article, dev_texts = self._get_training_data(training_dir,
+                                                                                      entity_descr_output,
+                                                                                      True,
+                                                                                      devlimit,
+                                                                                      to_print=False)
+        shuffle(dev_ent)
+
+        dev_pos_count = len([g for g in dev_gold.values() if g])
+        dev_neg_count = len([g for g in dev_gold.values() if not g])
 
-        dev_inst, dev_pos, dev_neg, dev_texts = self._get_training_data(training_dir,
-                                                                        entity_descr_output,
-                                                                        True,
-                                                                        devlimit,
-                                                                        balance=False,
-                                                                        to_print=False)
         self._begin_training()
 
         print()
-        self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_random", calc_random=True)
-        self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_pre", avg=False)
-
-        instance_pos_count = 0
-        instance_neg_count = 0
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_random", calc_random=True)
+        print()
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_pre", avg=True)
 
         if to_print:
             print()
-            print("Training on", len(train_inst.values()), "articles")
-            print("Dev test on", len(dev_inst.values()), "articles")
+            print("Training on", len(train_ent), "entities in", len(train_texts), "articles")
+            print("Training instances pos/neg", train_pos_count, train_neg_count)
+            print()
+            print("Dev test on", len(dev_ent), "entities in", len(dev_texts), "articles")
+            print("Dev instances pos/neg", dev_pos_count, dev_neg_count)
             print()
             print(" CUTOFF", self.CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
             print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
-            print(" ENTITY_WIDTH", self.ENTITY_WIDTH)
+            print(" DESC_WIDTH", self.DESC_WIDTH)
             print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
             print(" HIDDEN_2_WIDTH", self.HIDDEN_2_WIDTH)
             print(" DROP", self.DROP)
             print()
 
-        # TODO: proper batches. Currently 1 article at the time
-        # TODO shuffle data (currently positive is always followed by several negatives)
-        article_count = 0
-        for article_id, inst_cluster_set in train_inst.items():
-            try:
-                # if to_print:
-                    # print()
-                    # print(article_count, "Training on article", article_id)
-                article_count += 1
-                article_text = train_texts[article_id]
-                entities = list()
-                golds = list()
-                for inst_cluster in inst_cluster_set:
-                    entities.append(train_pos.get(inst_cluster))
-                    golds.append(float(1.0))
-                    instance_pos_count += 1
-                    for neg_entity in train_neg.get(inst_cluster, []):
-                        entities.append(neg_entity)
-                        golds.append(float(0.0))
-                        instance_neg_count += 1
+        start = 0
+        stop = min(self.BATCH_SIZE, len(train_ent))
+        processed = 0
 
-                self.update(article_text=article_text, entities=entities, golds=golds)
+        while start < len(train_ent):
+            next_batch = train_ent[start:stop]
 
-                # dev eval
-                self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_inter_avg", avg=True)
-            except ValueError as e:
-                print("Error in article id", article_id)
+            golds = [train_gold[e] for e in next_batch]
+            descs = [train_desc[e] for e in next_batch]
+            articles = [train_texts[train_article[e]] for e in next_batch]
+
+            self.update(entities=next_batch, golds=golds, descs=descs, texts=articles)
+            self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_inter", avg=True)
+
+            processed += len(next_batch)
+
+            start = start + self.BATCH_SIZE
+            stop = min(stop + self.BATCH_SIZE, len(train_ent))
 
         if to_print:
             print()
-            print("Trained on", instance_pos_count, "/", instance_neg_count, "instances pos/neg")
+            print("Trained on", processed, "entities in total")
 
-    def _test_dev(self, instances, pos, neg, texts_by_id, print_string, avg=False, calc_random=False):
-        predictions = list()
-        golds = list()
+    def _test_dev(self, entities, gold_by_entity, desc_by_entity, article_by_entity, texts_by_id, print_string, avg=True, calc_random=False):
+        golds = [gold_by_entity[e] for e in entities]
 
-        for article_id, inst_cluster_set in instances.items():
-            for inst_cluster in inst_cluster_set:
-                pos_ex = pos.get(inst_cluster)
-                neg_exs = neg.get(inst_cluster, [])
-
-                article = inst_cluster.split(sep="_")[0]
-                entity_id = inst_cluster.split(sep="_")[1]
-                article_doc = self.nlp(texts_by_id[article])
-                entities = [self.nlp(pos_ex)]
-                golds.append(float(1.0))
-                for neg_ex in neg_exs:
-                    entities.append(self.nlp(neg_ex))
-                    golds.append(float(0.0))
-
-                if calc_random:
-                    preds = self._predict_random(entities=entities)
-                else:
-                    preds = self._predict(article_doc=article_doc, entities=entities, avg=avg)
-                predictions.extend(preds)
+        if calc_random:
+            predictions = self._predict_random(entities=entities)
+        else:
+            desc_docs = self.nlp.pipe([desc_by_entity[e] for e in entities])
+            article_docs = self.nlp.pipe([texts_by_id[article_by_entity[e]] for e in entities])
+            predictions = self._predict(entities=entities, article_docs=article_docs, desc_docs=desc_docs, avg=avg)
 
         # TODO: combine with prior probability
-        p, r, f = run_el.evaluate(predictions, golds, to_print=False)
-        if self.PRINT_F:
-            print("p/r/F", print_string, round(p, 1), round(r, 1), round(f, 1))
-
+        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False)
         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds))
-        if self.PRINT_LOSS:
-            print("loss", print_string, round(loss, 5))
+
+        print("p/r/F/acc/loss", print_string, round(p, 1), round(r, 1), round(f, 1), round(acc, 2), round(loss, 5))
 
         return loss, p, r, f
 
-    def _predict(self, article_doc, entities, avg=False, apply_threshold=True):
+    def _predict(self, entities, article_docs, desc_docs, avg=True, apply_threshold=True):
         if avg:
             with self.article_encoder.use_params(self.sgd_article.averages) \
-                 and self.entity_encoder.use_params(self.sgd_entity.averages):
-                doc_encoding = self.article_encoder([article_doc])[0]
-                entity_encodings = self.entity_encoder(entities)
+                 and self.desc_encoder.use_params(self.sgd_entity.averages):
+                doc_encodings = self.article_encoder(article_docs)
+                desc_encodings = self.desc_encoder(desc_docs)
 
         else:
-            doc_encoding = self.article_encoder([article_doc])[0]
-            entity_encodings = self.entity_encoder(entities)
+            doc_encodings = self.article_encoder(article_docs)
+            desc_encodings = self.desc_encoder(desc_docs)
 
-        concat_encodings = [list(entity_encodings[i]) + list(doc_encoding) for i in range(len(entities))]
+        concat_encodings = [list(desc_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
         np_array_list = np.asarray(concat_encodings)
 
         if avg:
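Note: the two while-loops above implement minority-class upsampling by drawing entity ids with replacement until the positive and negative counts match, and the start/stop arithmetic then walks the shuffled list in fixed-size batches. A self-contained sketch of both patterns with toy data (all names and values made up, not part of the commit):

import random
from random import shuffle

# Toy gold labels per entity id, standing in for train_gold.
gold = {0: 1, 1: 0, 2: 0, 3: 0, 4: 1}
ents = list(gold)  # entity ids, like train_ent

pos = [e for e in ents if gold[e]]
neg = [e for e in ents if not gold[e]]
pos_count, neg_count = len(pos), len(neg)

# Upsample the minority class with replacement until counts match;
# only one of the two loops actually runs.
while pos_count < neg_count:
    ents.append(random.choice(pos))
    pos_count += 1
while neg_count < pos_count:
    ents.append(random.choice(neg))
    neg_count += 1

shuffle(ents)

# Fixed-size batching with the same start/stop arithmetic as train_model.
BATCH_SIZE = 2
start, stop = 0, min(BATCH_SIZE, len(ents))
while start < len(ents):
    batch = ents[start:stop]
    # ... call update() on this batch ...
    start += BATCH_SIZE
    stop = min(stop + BATCH_SIZE, len(ents))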
@@ -189,16 +187,16 @@ class EL_Model:
 
     def _predict_random(self, entities, apply_threshold=True):
         if not apply_threshold:
-            return [float(random.uniform(0,1)) for e in entities]
+            return [float(random.uniform(0, 1)) for e in entities]
         else:
-            return [float(1.0) if random.uniform(0,1) > self.CUTOFF else float(0.0) for e in entities]
+            return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for e in entities]
 
-    def _build_cnn(self, in_width, entity_width, article_width, hidden_1_width, hidden_2_width):
+    def _build_cnn(self, in_width, desc_width, article_width, hidden_1_width, hidden_2_width):
         with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-            self.entity_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=entity_width)
+            self.desc_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=desc_width)
             self.article_encoder = self._encoder(in_width=in_width, hidden_with=hidden_1_width, end_width=article_width)
 
-            in_width = entity_width + article_width
+            in_width = desc_width + article_width
             out_width = hidden_2_width
 
             self.model = Affine(out_width, in_width) \
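Note: the classifier's input is the per-entity description encoding concatenated with the article encoding, so its input width is desc_width + article_width (64 + 64 = 128 with the constants in this commit). A quick shape check with stand-in numpy arrays (made-up values, illustration only):

import numpy as np

DESC_WIDTH, ARTICLE_WIDTH = 64, 64
n_entities = 5

# Stand-ins for the two encoder outputs (random, for shape only).
desc_encodings = np.random.rand(n_entities, DESC_WIDTH)
doc_encodings = np.random.rand(n_entities, ARTICLE_WIDTH)

concat = np.hstack([desc_encodings, doc_encodings])
assert concat.shape == (n_entities, DESC_WIDTH + ARTICLE_WIDTH)  # (5, 128)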
@@ -229,80 +227,78 @@ class EL_Model:
 
     def _begin_training(self):
         self.sgd_article = create_default_optimizer(self.article_encoder.ops)
-        self.sgd_entity = create_default_optimizer(self.entity_encoder.ops)
+        self.sgd_entity = create_default_optimizer(self.desc_encoder.ops)
         self.sgd = create_default_optimizer(self.model.ops)
 
     @staticmethod
     def get_loss(predictions, golds):
         d_scores = (predictions - golds)
+        gradient = d_scores.mean()
         loss = (d_scores ** 2).mean()
-        return loss, d_scores
+        return loss, gradient
 
-    # TODO: multiple docs/articles
-    def update(self, article_text, entities, golds, apply_threshold=True):
-        article_doc = self.nlp(article_text)
-        # entity_docs = list(self.nlp.pipe(entities))
+    def update(self, entities, golds, descs, texts):
+        golds = self.model.ops.asarray(golds)
 
-        for entity, gold in zip(entities, golds):
-            doc_encodings, bp_doc = self.article_encoder.begin_update([article_doc], drop=self.DROP)
-            doc_encoding = doc_encodings[0]
+        desc_docs = self.nlp.pipe(descs)
+        article_docs = self.nlp.pipe(texts)
 
-            entity_doc = self.nlp(entity)
-            # print("entity_docs", type(entity_doc))
+        doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP)
 
-            entity_encodings, bp_entity = self.entity_encoder.begin_update([entity_doc], drop=self.DROP)
-            entity_encoding = entity_encodings[0]
-            # print("entity_encoding", len(entity_encoding), entity_encoding)
+        desc_encodings, bp_entity = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
 
-            concat_encodings = [list(entity_encoding) + list(doc_encoding)]  #  for i in range(len(entities))
-            # print("concat_encodings", len(concat_encodings), concat_encodings)
+        concat_encodings = [list(desc_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))]
 
-            prediction, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
-            # predictions = self.model.ops.flatten(predictions)
+        predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP)
+        predictions = self.model.ops.flatten(predictions)
 
-            # print("prediction", prediction)
-            # golds = self.model.ops.asarray(golds)
-            # print("gold", gold)
+        # print("entities", entities)
+        # print("predictions", predictions)
+        # print("golds", golds)
 
-            loss, gradient = self.get_loss(prediction, gold)
+        loss, gradient = self.get_loss(predictions, golds)
 
-            if self.PRINT_LOSS and self.PRINT_TRAIN:
-                print("loss train", round(loss, 5))
+        if self.PRINT_TRAIN:
+            print("loss train", round(loss, 5))
 
-            gradient = float(gradient)
-            # print("gradient", gradient)
-            # print("loss", loss)
+        gradient = float(gradient)
+        # print("gradient", gradient)
+        # print("loss", loss)
 
-            model_gradient = bp_model(gradient, sgd=self.sgd)
-            # print("model_gradient", model_gradient)
+        model_gradient = bp_model(gradient, sgd=self.sgd)
+        # print("model_gradient", model_gradient)
 
-            # concat = entity + doc, but doc is the same within this function (TODO: multiple docs/articles)
-            doc_gradient = model_gradient[0][self.ENTITY_WIDTH:]
-            entity_gradients = list()
-            for x in model_gradient:
-                entity_gradients.append(list(x[0:self.ENTITY_WIDTH]))
+        # concat = desc + doc, but doc is the same within this function (TODO: multiple docs/articles)
+        doc_gradient = model_gradient[0][self.DESC_WIDTH:]
+        entity_gradients = list()
+        for x in model_gradient:
+            entity_gradients.append(list(x[0:self.DESC_WIDTH]))
 
-            # print("doc_gradient", doc_gradient)
-            # print("entity_gradients", entity_gradients)
+        # print("doc_gradient", doc_gradient)
+        # print("entity_gradients", entity_gradients)
 
-            bp_doc([doc_gradient], sgd=self.sgd_article)
-            bp_entity(entity_gradients, sgd=self.sgd_entity)
+        bp_doc([doc_gradient], sgd=self.sgd_article)
+        bp_entity(entity_gradients, sgd=self.sgd_entity)
 
-    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print):
+    def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
         id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
 
         correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
                                                                                          collect_correct=True,
                                                                                          collect_incorrect=True)
 
-        instance_by_article = dict()
         local_vectors = list()   # TODO: local vectors
         text_by_article = dict()
-        pos_entities = dict()
-        neg_entities = dict()
+        gold_by_entity = dict()
+        desc_by_entity = dict()
+        article_by_entity = dict()
+        entities = list()
 
         cnt = 0
-        for f in listdir(training_dir):
+        next_entity_nr = 0
+        files = listdir(training_dir)
+        shuffle(files)
+        for f in files:
             if not limit or cnt < limit:
                 if dev == run_el.is_dev(f):
                     article_id = f.replace(".txt", "")
@@ -313,29 +309,29 @@ class EL_Model:
                         with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
                             text = file.read()
                             text_by_article[article_id] = text
-                            instance_by_article[article_id] = set()
 
                     for mention, entity_pos in correct_entries[article_id].items():
                         descr = id_to_descr.get(entity_pos)
                         if descr:
-                            instance_by_article[article_id].add(article_id + "_" + mention)
-                            pos_entities[article_id + "_" + mention] = descr
+                            entities.append(next_entity_nr)
+                            gold_by_entity[next_entity_nr] = 1
+                            desc_by_entity[next_entity_nr] = descr
+                            article_by_entity[next_entity_nr] = article_id
+                            next_entity_nr += 1
 
                     for mention, entity_negs in incorrect_entries[article_id].items():
-                        if not balance or pos_entities.get(article_id + "_" + mention):
-                            neg_count = 0
-                            for entity_neg in entity_negs:
-                                # if balance, keep only 1 negative instance for each positive instance
-                                if neg_count < 1 or not balance:
-                                    descr = id_to_descr.get(entity_neg)
-                                    if descr:
-                                        descr_list = neg_entities.get(article_id + "_" + mention, [])
-                                        descr_list.append(descr)
-                                        neg_entities[article_id + "_" + mention] = descr_list
-                                        neg_count += 1
+                        for entity_neg in entity_negs:
+                            descr = id_to_descr.get(entity_neg)
+                            if descr:
+                                entities.append(next_entity_nr)
+                                gold_by_entity[next_entity_nr] = 0
+                                desc_by_entity[next_entity_nr] = descr
+                                article_by_entity[next_entity_nr] = article_id
+                                next_entity_nr += 1
 
         if to_print:
             print()
             print("Processed", cnt, "training articles, dev=" + str(dev))
             print()
-        return instance_by_article, pos_entities, neg_entities, text_by_article
+        return entities, gold_by_entity, desc_by_entity, article_by_entity, text_by_article
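Note: get_loss now returns the mean of the raw score differences as a single scalar gradient (update() then casts it with float() before backprop) instead of the per-example difference vector. For the mean-squared-error loss used here, the exact per-prediction derivative would be 2 * (pred - gold) / n, so the returned value is a scalar approximation of that. A tiny numpy check of the quantities involved (made-up numbers, illustration only):

import numpy as np

predictions = np.asarray([0.9, 0.2, 0.4])
golds = np.asarray([1.0, 0.0, 1.0])

d_scores = predictions - golds
loss = (d_scores ** 2).mean()   # mean squared error, as in get_loss
gradient = d_scores.mean()      # scalar mean difference returned by the patch

# The exact MSE derivative per prediction would instead be:
exact = 2 * d_scores / len(d_scores)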
@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=400, devlimit=50)
         print()
 
     # STEP 7: apply the EL algorithm on the dev dataset