obtain sentence for each mention

svlandeg 2019-05-23 15:37:05 +02:00
parent 97241a3ed7
commit 4392c01b7b
3 changed files with 112 additions and 43 deletions


@@ -70,7 +70,7 @@ def is_dev(file_name):
     return file_name.endswith("3.txt")


-def evaluate(predictions, golds, to_print=True):
+def evaluate(predictions, golds, to_print=True, times_hundred=True):
     if len(predictions) != len(golds):
         raise ValueError("predictions and gold entities should have the same length")

@@ -101,8 +101,11 @@ def evaluate(predictions, golds, to_print=True):
         print("fp", fp)
         print("fn", fn)

-    precision = 100 * tp / (tp + fp + 0.0000001)
-    recall = 100 * tp / (tp + fn + 0.0000001)
+    precision = tp / (tp + fp + 0.0000001)
+    recall = tp / (tp + fn + 0.0000001)
+    if times_hundred:
+        precision = precision*100
+        recall = recall*100
     fscore = 2 * recall * precision / (recall + precision + 0.0000001)
     accuracy = corrects / (corrects + incorrects)
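
Note: the only behavioural change in evaluate() is the new times_hundred flag, which controls whether precision and recall are scaled to percentages. A minimal standalone sketch of the resulting arithmetic (the function name and the example counts are made up for illustration):

# Hedged sketch, not the repo's exact code: with times_hundred=False the
# scores stay in [0, 1], which is how the training loop consumes them below.
def metrics(tp, fp, fn, corrects, incorrects, times_hundred=True):
    eps = 0.0000001  # same guard against division by zero as in the diff
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    if times_hundred:
        precision = precision * 100
        recall = recall * 100
    fscore = 2 * recall * precision / (recall + precision + eps)
    accuracy = corrects / (corrects + incorrects)
    return precision, recall, fscore, accuracy

print(metrics(tp=8, fp=2, fn=4, corrects=8, incorrects=6, times_hundred=False))
# -> roughly (0.8, 0.667, 0.727, 0.571)

Since the F-score formula is homogeneous in precision and recall, it comes out on the same scale as its inputs, so no separate flag is needed for it.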


@@ -20,6 +20,7 @@ from thinc.t2t import ParametricAttention
 from thinc.misc import Residual
 from thinc.misc import LayerNorm as LN

+from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc

 """ TODO: this code needs to be implemented in pipes.pyx"""

@@ -27,13 +28,16 @@ from spacy.tokens import Doc
 class EL_Model:

+    PRINT_INSPECT = False
     PRINT_TRAIN = False
+
     EPS = 0.0000000005
     CUTOFF = 0.5

     BATCH_SIZE = 5

-    INPUT_DIM = 300
+    DOC_CUTOFF = 300  # number of characters from the doc context
+    INPUT_DIM = 300  # dimension of pre-trained vectors

     HIDDEN_1_WIDTH = 32  # 10
     HIDDEN_2_WIDTH = 32  # 6
     DESC_WIDTH = 64  # 4
@@ -58,11 +62,20 @@ class EL_Model:
         # raise errors instead of runtime warnings in case of int/float overflow
         np.seterr(all='raise')

-        train_ent, train_gold, train_desc, train_article, train_texts = self._get_training_data(training_dir,
-                                                                                                entity_descr_output,
-                                                                                                False,
-                                                                                                trainlimit,
-                                                                                                to_print=False)
+        train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
+            self._get_training_data(training_dir, entity_descr_output, False, trainlimit, to_print=False)
+
+        # inspect data
+        if self.PRINT_INSPECT:
+            for entity in train_ent:
+                print("entity", entity)
+                print("gold", train_gold[entity])
+                print("desc", train_desc[entity])
+                print("sentence ID", train_sent[entity])
+                print("sentence text", train_sent_texts[train_sent[entity]])
+                print("article ID", train_art[entity])
+                print("article text", train_art_texts[train_art[entity]])
+                print()

         train_pos_entities = [k for k,v in train_gold.items() if v]
         train_neg_entities = [k for k,v in train_gold.items() if not v]
@@ -70,6 +83,10 @@ class EL_Model:
         train_pos_count = len(train_pos_entities)
         train_neg_count = len(train_neg_entities)

+        if to_print:
+            print()
+            print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
+
         # upsample positives to 50-50 distribution
         while train_pos_count < train_neg_count:
             train_ent.append(random.choice(train_pos_entities))
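
Note: the loop above balances the classes by drawing positive entities with replacement until the counts match. A self-contained sketch of the same idea (hypothetical helper name, dummy data):

import random

# Hedged sketch of the 50-50 upsampling: random positives are appended
# (sampled with replacement) until they equal the negative count.
def upsample_positives(train_ent, train_gold):
    pos = [k for k, v in train_gold.items() if v]
    neg = [k for k, v in train_gold.items() if not v]
    pos_count, neg_count = len(pos), len(neg)
    while pos_count < neg_count:
        train_ent.append(random.choice(pos))  # duplicate a random positive
        pos_count += 1
    return train_ent

gold = {"E_1": 1, "E_2": 0, "E_3": 0, "E_4": 0}
print(upsample_positives(list(gold), gold))  # "E_1" ends up in the list three times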
@@ -82,11 +99,8 @@ class EL_Model:
         shuffle(train_ent)

-        dev_ent, dev_gold, dev_desc, dev_article, dev_texts = self._get_training_data(training_dir,
-                                                                                      entity_descr_output,
-                                                                                      True,
-                                                                                      devlimit,
-                                                                                      to_print=False)
+        dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
+            self._get_training_data(training_dir, entity_descr_output, True, devlimit, to_print=False)
         shuffle(dev_ent)

         dev_pos_count = len([g for g in dev_gold.values() if g])
@@ -94,20 +108,16 @@ class EL_Model:
         self._begin_training()

-        print()
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_random", calc_random=True)
-        print()
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_pre", avg=True)
-
         if to_print:
             print()
-            print("Training on", len(train_ent), "entities in", len(train_texts), "articles")
-            print("Training instances pos/neg", train_pos_count, train_neg_count)
+            print("Training on", len(train_ent), "entities in", len(train_art_texts), "articles")
+            print("Training instances pos/neg:", train_pos_count, train_neg_count)
             print()
-            print("Dev test on", len(dev_ent), "entities in", len(dev_texts), "articles")
-            print("Dev instances pos/neg", dev_pos_count, dev_neg_count)
+            print("Dev test on", len(dev_ent), "entities in", len(dev_art_texts), "articles")
+            print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
             print()
             print(" CUTOFF", self.CUTOFF)
+            print(" DOC_CUTOFF", self.DOC_CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
             print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
             print(" DESC_WIDTH", self.DESC_WIDTH)
@@ -116,6 +126,10 @@ class EL_Model:
             print(" DROP", self.DROP)
             print()

+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_random", calc_random=True)
+        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_pre", avg=True)
+        print()
+
         start = 0
         stop = min(self.BATCH_SIZE, len(train_ent))
         processed = 0
@@ -125,10 +139,10 @@ class EL_Model:
             golds = [train_gold[e] for e in next_batch]
             descs = [train_desc[e] for e in next_batch]
-            articles = [train_texts[train_article[e]] for e in next_batch]
+            articles = [train_art_texts[train_art[e]] for e in next_batch]

             self.update(entities=next_batch, golds=golds, descs=descs, texts=articles)
-            self._test_dev(dev_ent, dev_gold, dev_desc, dev_article, dev_texts, print_string="dev_inter", avg=True)
+            self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, print_string="dev_inter", avg=True)

             processed += len(next_batch)
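
Note: training walks over the shuffled entity list in fixed-size windows of BATCH_SIZE (the start/stop bookkeeping above). A hedged sketch of that pattern, with the real update()/_test_dev() calls replaced by a stub, since the full loop body is not shown in this hunk:

# Sketch of the fixed-size batching loop; process_batch stands in for the
# model update and the intermediate dev evaluation.
BATCH_SIZE = 5

def iterate_batches(train_ent, process_batch):
    start = 0
    stop = min(BATCH_SIZE, len(train_ent))
    while start < len(train_ent):
        process_batch(train_ent[start:stop])
        start = stop
        stop = min(stop + BATCH_SIZE, len(train_ent))

iterate_batches(list(range(12)), print)  # prints batches of 5, 5 and 2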
@@ -151,7 +165,7 @@ class EL_Model:
         predictions = self._predict(entities=entities, article_docs=article_docs, desc_docs=desc_docs, avg=avg)

         # TODO: combine with prior probability
-        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False)
+        p, r, f, acc = run_el.evaluate(predictions, golds, to_print=False, times_hundred=False)
         loss, gradient = self.get_loss(self.model.ops.asarray(predictions), self.model.ops.asarray(golds))
         print("p/r/F/acc/loss", print_string, round(p, 1), round(r, 1), round(f, 1), round(acc, 2), round(loss, 5))
@@ -288,14 +302,18 @@ class EL_Model:
                                                             collect_incorrect=True)

         local_vectors = list()   # TODO: local vectors
-        text_by_article = dict()
+        entities = set()
         gold_by_entity = dict()
         desc_by_entity = dict()
         article_by_entity = dict()
-        entities = list()
+        text_by_article = dict()
+        sentence_by_entity = dict()
+        text_by_sentence = dict()

         cnt = 0
-        next_entity_nr = 0
+        next_entity_nr = 1
+        next_sent_nr = 1
         files = listdir(training_dir)
         shuffle(files)
         for f in files:
@@ -305,33 +323,81 @@ class EL_Model:
             if cnt % 500 == 0 and to_print:
                 print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
             cnt += 1
-            if article_id not in text_by_article:
-                with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                    text = file.read()
-                    text_by_article[article_id] = text
+
+            # parse the article text
+            with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
+                text = file.read()
+                article_doc = self.nlp(text)
+                truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
+                text_by_article[article_id] = truncated_text
+
+            # process all positive and negative entities, collect all relevant mentions in this article
+            article_terms = set()
+            entities_by_mention = dict()

             for mention, entity_pos in correct_entries[article_id].items():
                 descr = id_to_descr.get(entity_pos)
                 if descr:
-                    entities.append(next_entity_nr)
-                    gold_by_entity[next_entity_nr] = 1
-                    desc_by_entity[next_entity_nr] = descr
-                    article_by_entity[next_entity_nr] = article_id
+                    entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
                     next_entity_nr += 1
+                    gold_by_entity[entity] = 1
+                    desc_by_entity[entity] = descr
+                    article_terms.add(mention)
+                    mention_entities = entities_by_mention.get(mention, set())
+                    mention_entities.add(entity)
+                    entities_by_mention[mention] = mention_entities

             for mention, entity_negs in incorrect_entries[article_id].items():
                 for entity_neg in entity_negs:
                     descr = id_to_descr.get(entity_neg)
                     if descr:
-                        entities.append(next_entity_nr)
-                        gold_by_entity[next_entity_nr] = 0
-                        desc_by_entity[next_entity_nr] = descr
-                        article_by_entity[next_entity_nr] = article_id
+                        entity = "E_" + str(next_entity_nr) + "_" + article_id + "_" + mention
                         next_entity_nr += 1
+                        gold_by_entity[entity] = 0
+                        desc_by_entity[entity] = descr
+                        article_terms.add(mention)
+                        mention_entities = entities_by_mention.get(mention, set())
+                        mention_entities.add(entity)
+                        entities_by_mention[mention] = mention_entities
+
+            # find all matches in the doc for the mentions
+            # TODO: fix this - doesn't look like all entities are found
+            matcher = PhraseMatcher(self.nlp.vocab)
+            patterns = list(self.nlp.tokenizer.pipe(article_terms))
+            matcher.add("TerminologyList", None, *patterns)
+            matches = matcher(article_doc)
+
+            # store sentences
+            sentence_to_id = dict()
+            for match_id, start, end in matches:
+                span = article_doc[start:end]
+                sent_text = span.sent
+                sent_nr = sentence_to_id.get(sent_text, None)
+                if sent_nr is None:
+                    sent_nr = "S_" + str(next_sent_nr) + article_id
+                    next_sent_nr += 1
+                    text_by_sentence[sent_nr] = sent_text
+                    sentence_to_id[sent_text] = sent_nr
+                mention_entities = entities_by_mention[span.text]
+                for entity in mention_entities:
+                    entities.add(entity)
+                    sentence_by_entity[entity] = sent_nr
+                    article_by_entity[entity] = article_id
+
+        # remove entities that didn't have all data
+        gold_by_entity = {k: v for k, v in gold_by_entity.items() if k in entities}
+        desc_by_entity = {k: v for k, v in desc_by_entity.items() if k in entities}
+        article_by_entity = {k: v for k, v in article_by_entity.items() if k in entities}
+        text_by_article = {k: v for k, v in text_by_article.items() if k in article_by_entity.values()}
+
+        sentence_by_entity = {k: v for k, v in sentence_by_entity.items() if k in entities}
+        text_by_sentence = {k: v for k, v in text_by_sentence.items() if k in sentence_by_entity.values()}

         if to_print:
             print()
             print("Processed", cnt, "training articles, dev=" + str(dev))
             print()
-        return entities, gold_by_entity, desc_by_entity, article_by_entity, text_by_article
+        return list(entities), gold_by_entity, desc_by_entity, article_by_entity, text_by_article, sentence_by_entity, text_by_sentence
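
Note: the heart of this commit is the PhraseMatcher block above: every mention string becomes a tokenized pattern, each hit comes back as a Span over the article Doc, and Span.sent yields the sentence containing it. A self-contained illustration of that technique against the spaCy v2 API used here (the sample text and mention list are invented; assumes en_core_web_md is installed, as in these scripts):

import spacy
from spacy.matcher import PhraseMatcher

# Mention -> sentence lookup, as in the hunk above. span.sent needs a
# pipeline with a parser (or sentencizer), which en_core_web_md provides.
nlp = spacy.load("en_core_web_md")
doc = nlp("Douglas Adams wrote the guide. Adams was born in Cambridge.")

matcher = PhraseMatcher(nlp.vocab)
# tokenize the mention strings with the same tokenizer as the article,
# so multi-word mentions line up with the doc's tokens
patterns = list(nlp.tokenizer.pipe(["Adams", "Cambridge"]))
matcher.add("TerminologyList", None, *patterns)  # spaCy v2 add() signature

for match_id, start, end in matcher(doc):
    span = doc[start:end]
    print(span.text, "->", span.sent.text)

Keying sentence_to_id on span.sent works because spaCy Spans over the same doc with the same boundaries compare and hash equal, so two mentions in one sentence reuse the same sentence ID.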


@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=400, devlimit=50)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=100, devlimit=20)
         print()

     # STEP 7: apply the EL algorithm on the dev dataset