From 4392c01b7bfb22e435249128ac15c196c5b50bd1 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Thu, 23 May 2019 15:37:05 +0200
Subject: [PATCH] obtain sentence for each mention

---
 .../pipeline/wiki_entity_linking/run_el.py    |   9 +-
 .../pipeline/wiki_entity_linking/train_el.py  | 144 +++++++++++++-----
 .../wiki_entity_linking/wiki_nel_pipeline.py  |   2 +-
 3 files changed, 112 insertions(+), 43 deletions(-)

diff --git a/examples/pipeline/wiki_entity_linking/run_el.py b/examples/pipeline/wiki_entity_linking/run_el.py
index 273543306..c0c219829 100644
--- a/examples/pipeline/wiki_entity_linking/run_el.py
+++ b/examples/pipeline/wiki_entity_linking/run_el.py
@@ -70,7 +70,7 @@ def is_dev(file_name):
     return file_name.endswith("3.txt")
 
 
-def evaluate(predictions, golds, to_print=True):
+def evaluate(predictions, golds, to_print=True, times_hundred=True):
     if len(predictions) != len(golds):
         raise ValueError("predictions and gold entities should have the same length")
 
@@ -101,8 +101,11 @@ def evaluate(predictions, golds, to_print=True):
         print("fp", fp)
         print("fn", fn)
 
-    precision = 100 * tp / (tp + fp + 0.0000001)
-    recall = 100 * tp / (tp + fn + 0.0000001)
+    precision = tp / (tp + fp + 0.0000001)
+    recall = tp / (tp + fn + 0.0000001)
+    if times_hundred:
+        precision = precision*100
+        recall = recall*100
     fscore = 2 * recall * precision / (recall + precision + 0.0000001)
     accuracy = corrects / (corrects + incorrects)
 
diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py
index cd6e9de4d..d8082635a 100644
--- a/examples/pipeline/wiki_entity_linking/train_el.py
+++ b/examples/pipeline/wiki_entity_linking/train_el.py
@@ -20,6 +20,7 @@ from thinc.t2t import ParametricAttention
 from thinc.misc import Residual
 from thinc.misc import LayerNorm as LN
 
+from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc
 
 """ TODO: this code needs to be implemented in pipes.pyx"""
@@ -27,13 +28,16 @@ from spacy.tokens import Doc
 
 
 class EL_Model:
+    PRINT_INSPECT = False
     PRINT_TRAIN = False
     EPS = 0.0000000005
     CUTOFF = 0.5
 
     BATCH_SIZE = 5
 
-    INPUT_DIM = 300
+    DOC_CUTOFF = 300  # number of characters from the doc context
+    INPUT_DIM = 300  # dimension of pre-trained vectors
+
     HIDDEN_1_WIDTH = 32  # 10
     HIDDEN_2_WIDTH = 32  # 6
     DESC_WIDTH = 64  # 4
@@ -58,11 +62,20 @@ class EL_Model:
         # raise errors instead of runtime warnings in case of int/float overflow
         np.seterr(all='raise')
 
-        train_ent, train_gold, train_desc, train_article, train_texts = self._get_training_data(training_dir,
-                                                                                                entity_descr_output,
-                                                                                                False,
-                                                                                                trainlimit,
-                                                                                                to_print=False)
+        train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
+            self._get_training_data(training_dir, entity_descr_output, False, trainlimit, to_print=False)
+
+        # inspect data
+        if self.PRINT_INSPECT:
+            for entity in train_ent:
+                print("entity", entity)
+                print("gold", train_gold[entity])
+                print("desc", train_desc[entity])
+                print("sentence ID", train_sent[entity])
+                print("sentence text", train_sent_texts[train_sent[entity]])
+                print("article ID", train_art[entity])
+                print("article text", train_art_texts[train_art[entity]])
+                print()
 
         train_pos_entities = [k for k,v in train_gold.items() if v]
         train_neg_entities = [k for k,v in train_gold.items() if not v]
@@ -70,6 +83,10 @@ class EL_Model:
         train_pos_count = len(train_pos_entities)
         train_neg_count = len(train_neg_entities)
 
+        if to_print:
+            print()
+            print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
+
         # upsample positives to 50-50 distribution
         while train_pos_count < train_neg_count:
             train_ent.append(random.choice(train_pos_entities))
@@ -82,11 +99,8 @@ class EL_Model:
 
         shuffle(train_ent)
 
-        dev_ent, dev_gold, dev_desc, dev_article, dev_texts = self._get_training_data(training_dir,
-                                                                                      entity_descr_output,
-                                                                                      True,
-                                                                                      devlimit,
-                                                                                      to_print=False)
+        dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
+            self._get_training_data(training_dir, entity_descr_output, True, devlimit, to_print=False)
         shuffle(dev_ent)
 
         dev_pos_count = len([g for g in dev_gold.values() if g])
@@ -94,20 +108,16 @@ class EL_Model:
 
         self._begin_training()
 
-        print()
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_random", calc_random=True)
-        print()
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_pre", avg=True)
-        if to_print:
+        if to_print:
             print()
-            print("Training on", len(train_ent), "entities in", len(train_texts), "articles")
-            print("Training instances pos/neg", train_pos_count, train_neg_count)
+            print("Training on", len(train_ent), "entities in", len(train_art_texts), "articles")
+            print("Training instances pos/neg:", train_pos_count, train_neg_count)
             print()
-            print("Dev test on", len(dev_ent), "entities in", len(dev_texts), "articles")
-            print("Dev instances pos/neg", dev_pos_count, dev_neg_count)
+            print("Dev test on", len(dev_ent), "entities in", len(dev_art_texts), "articles")
+            print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
             print()
             print(" CUTOFF", self.CUTOFF)
+            print(" DOC_CUTOFF", self.DOC_CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
             print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
             print(" DESC_WIDTH", self.DESC_WIDTH)
@@ -116,6 +126,10 @@ class EL_Model:
             print(" DROP", self.DROP)
             print()
 
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_random", calc_random=True)
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_pre", avg=True)
+        print()
+
         start = 0
         stop = min(self.BATCH_SIZE, len(train_ent))
         processed = 0
@@ -125,10 +139,10 @@ class EL_Model:
             golds = [train_gold[e] for e in next_batch]
             descs = [train_desc[e] for e in next_batch]
-            articles = [train_texts[train_article[e]] for e in next_batch]
+            articles = [train_art_texts[train_art[e]] for e in next_batch]
 
             self.update(entities=next_batch, golds=golds, descs=descs, texts=articles)
-            self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_inter", avg=True)
+            self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_inter", avg=True)
 
             processed += len(next_batch)
 
@@ -151,7 +165,7 @@ class EL_Model:
         predictions = self._predict(entities=entities, article_docs=article_docs, desc_docs=desc_docs, avg=avg)
 
         # TODO: combine with prior probability
-        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False)
+        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False, times_hundred=False)
         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds))
         print("p/r/F/acc/loss", print_string, round(p, 1), round(r, 1), round(f, 1), round(acc, 2), round(loss, 5))
 
@@ -288,14 +302,18 @@ class EL_Model:
                                                                            collect_incorrect=True)
 
         local_vectors = list()   # TODO: local vectors
-        text_by_article = dict()
+
+        entities = set()
         gold_by_entity = dict()
         desc_by_entity = dict()
         article_by_entity = dict()
-        entities = list()
+        text_by_article = dict()
+        sentence_by_entity = dict()
+        text_by_sentence = dict()
 
         cnt = 0
-        next_entity_nr = 0
+        next_entity_nr = 1
+        next_sent_nr = 1
         files = listdir(training_dir)
         shuffle(files)
         for f in files:
@@ -305,33 +323,81 @@ class EL_Model:
             if cnt % 500 == 0 and to_print:
                 print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
             cnt += 1
-            if article_id not in text_by_article:
-                with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                    text = file.read()
-                    text_by_article[article_id] = text
+
+            # parse the article text
+            with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
+                text = file.read()
+                article_doc = self.nlp(text)
+                truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
+                text_by_article[article_id] = truncated_text
+
+            # process all positive and negative entities, collect all relevant mentions in this article
+            article_terms = set()
+            entities_by_mention = dict()
 
             for mention, entity_pos in correct_entries[article_id].items():
                 descr = id_to_descr.get(entity_pos)
                 if descr:
-                    entities.append(next_entity_nr)
-                    gold_by_entity[next_entity_nr] = 1
-                    desc_by_entity[next_entity_nr] = descr
-                    article_by_entity[next_entity_nr] = article_id
+                    entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
                     next_entity_nr += 1
+                    gold_by_entity[entity] = 1
+                    desc_by_entity[entity] = descr
+                    article_terms.add(mention)
+                    mention_entities = entities_by_mention.get(mention, set())
+                    mention_entities.add(entity)
+                    entities_by_mention[mention] = mention_entities
 
             for mention, entity_negs in incorrect_entries[article_id].items():
                 for entity_neg in entity_negs:
                     descr = id_to_descr.get(entity_neg)
                     if descr:
-                        entities.append(next_entity_nr)
-                        gold_by_entity[next_entity_nr] = 0
-                        desc_by_entity[next_entity_nr] = descr
-                        article_by_entity[next_entity_nr] = article_id
+                        entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
                         next_entity_nr += 1
+                        gold_by_entity[entity] = 0
+                        desc_by_entity[entity] = descr
+                        article_terms.add(mention)
+                        mention_entities = entities_by_mention.get(mention, set())
+                        mention_entities.add(entity)
+                        entities_by_mention[mention] = mention_entities
+
+            # find all matches in the doc for the mentions
+            # TODO: fix this - doesn't look like all entities are found
+            matcher = PhraseMatcher(self.nlp.vocab)
+            patterns = list(self.nlp.tokenizer.pipe(article_terms))
+
+            matcher.add("TerminologyList", None, *patterns)
+            matches = matcher(article_doc)
+
+            # store sentences
+            sentence_to_id = dict()
+            for match_id, start, end in matches:
+                span = article_doc[start:end]
+                sent_text = span.sent
+                sent_nr = sentence_to_id.get(sent_text, None)
+                if sent_nr is None:
+                    sent_nr = "S_" + str(next_sent_nr) + article_id
+                    next_sent_nr += 1
+                    text_by_sentence[sent_nr] = sent_text
+                    sentence_to_id[sent_text] = sent_nr
+                mention_entities = entities_by_mention[span.text]
+                for entity in mention_entities:
+                    entities.add(entity)
+                    sentence_by_entity[entity] = sent_nr
+                    article_by_entity[entity] = article_id
+
+        # remove entities that didn't have all data
+        gold_by_entity = {k: v for k, v in gold_by_entity.items() if k in entities}
+        desc_by_entity = {k: v for k, v in desc_by_entity.items() if k in entities}
+
+        article_by_entity = {k: v for k, v in article_by_entity.items() if k in entities}
+        text_by_article = {k: v for k, v in text_by_article.items() if k in article_by_entity.values()}
+
+        sentence_by_entity = {k: v for k, v in sentence_by_entity.items() if k in entities}
+        text_by_sentence = {k: v for k, v in text_by_sentence.items() if k in sentence_by_entity.values()}
 
         if to_print:
             print()
             print("Processed", cnt, "training articles, dev=" + str(dev))
             print()
 
-        return entities, gold_by_entity, desc_by_entity, article_by_entity, text_by_article
+        return list(entities), gold_by_entity, desc_by_entity, article_by_entity, text_by_article, sentence_by_entity, text_by_sentence
 
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index 715282642..319b1e1c8 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -111,7 +111,7 @@ if __name__ == "__main__":
 
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=400, devlimit=50)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20)
         print()
         # STEP 7: apply the EL algorithm on the dev dataset
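
Note on the main technique in this patch: for every positive or negative entity, train_el.py now records the sentence its mention occurs in, by running a PhraseMatcher over the parsed article and taking span.sent for each match. The snippet below is an illustrative sketch, not part of the patch; the example text, mention list and variable names are invented, and it assumes spaCy 2.x with 'en_core_web_md' installed (the model the pipeline script loads). Relevant to the in-diff TODO about mentions not being found: PhraseMatcher matches on the ORTH attribute by default, i.e. exactly and case-sensitively, so mentions whose surface form or tokenization differs from the article text produce no match.

import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.load("en_core_web_md")  # the parser provides the sentence boundaries used by span.sent

text = "Douglas Adams wrote The Hitchhiker's Guide to the Galaxy. Adams was born in Cambridge."
mentions = ["Douglas Adams", "Cambridge"]  # hypothetical mention strings

doc = nlp(text)
matcher = PhraseMatcher(nlp.vocab)  # matches on ORTH by default: exact and case-sensitive
patterns = list(nlp.tokenizer.pipe(mentions))
matcher.add("MENTIONS", None, *patterns)  # spaCy 2.x signature: (key, on_match, *patterns)

sentences_by_mention = dict()
for match_id, start, end in matcher(doc):
    span = doc[start:end]
    # keep the sentence text rather than the Span, so it survives after the Doc is discarded
    sentences_by_mention.setdefault(span.text, set()).add(span.sent.text)

for mention, sentences in sentences_by_mention.items():
    print(mention, "->", sentences)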
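The run_el.evaluate() change separates the metric computation from the percentage scaling: _test_dev() now calls it with times_hundred=False, presumably so the printed p/r/F values sit on the same 0-1 scale as accuracy and the loss. A small self-contained sketch of that arithmetic follows (not the patch's code; the helper name and the counts are invented for illustration):

def prf(tp, fp, fn, times_hundred=True, eps=0.0000001):
    # the epsilon mirrors the 0.0000001 guard against division by zero in evaluate()
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    if times_hundred:
        precision = precision * 100
        recall = recall * 100
    fscore = 2 * recall * precision / (recall + precision + eps)
    return precision, recall, fscore

print(prf(8, 2, 4))                       # ~ (80.0, 66.7, 72.7): percentages for printing
print(prf(8, 2, 4, times_hundred=False))  # ~ (0.8, 0.667, 0.727): raw fractions on the 0-1 scale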