eval on dev set, varying combo's of prior and context scores

This commit is contained in:
svlandeg 2019-06-11 11:40:58 +02:00
parent 83dc7b46fd
commit fe1ed432ef
4 changed files with 127 additions and 70 deletions

View File

@ -20,7 +20,7 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
""" Create the knowledge base from Wikidata entries """
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
# disable parts of the pipeline when rerunning
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
read_raw_data = False
if read_raw_data:

View File

@ -21,29 +21,10 @@ def run_kb_toy_example(kb):
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
print()
def run_el_toy_example(nlp, kb):
_prepare_pipeline(nlp, kb)
candidates = kb.get_candidates("Bush")
print("generating candidates for 'Bush' :")
for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
print()
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel. " \
"The main character in Doug's novel is the man Arthur Dent, " \
"but Douglas doesn't write about George Washington or Homer Simpson."
doc = nlp(text)
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
def run_el_dev(nlp, kb, training_dir, limit=None):
_prepare_pipeline(nlp, kb)
correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir,
collect_correct=True,
collect_incorrect=False)

View File

@ -6,7 +6,6 @@ import random
from spacy.util import minibatch, compounding
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
from examples.pipeline.wiki_entity_linking.train_el import EL_Model
import spacy
from spacy.vocab import Vocab
@ -30,10 +29,11 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES = 10
MIN_PAIR_OCC = 5
DOC_CHAR_CUTOFF = 300
EPOCHS = 5
EPOCHS = 10
DROPOUT = 0.1
if __name__ == "__main__":
def run_pipeline():
print("START", datetime.datetime.now())
print()
nlp = spacy.load('en_core_web_lg')
@ -51,15 +51,11 @@ if __name__ == "__main__":
# create training dataset
create_wp_training = False
# train the EL pipe
train_pipe = True
# run EL training
run_el_training = False
# apply named entity linking to the dev dataset
apply_to_dev = False
to_test_pipeline = False
# test the EL pipe on a simple example
to_test_pipeline = True
# STEP 1 : create prior probabilities from WP
# run only once !
@ -119,10 +115,11 @@ if __name__ == "__main__":
# STEP 6: create the entity linking pipe
if train_pipe:
id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
train_limit = 10
train_limit = 5
dev_limit = 2
print("Training on", train_limit, "articles")
print("Dev testing on", dev_limit, "articles")
print()
train_data = training_set_creator.read_training(nlp=nlp,
training_dir=TRAINING_DIR,
@ -130,6 +127,12 @@ if __name__ == "__main__":
limit=train_limit,
to_print=False)
dev_data = training_set_creator.read_training(nlp=nlp,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit,
to_print=False)
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
nlp.add_pipe(el_pipe, last=True)
@ -137,12 +140,12 @@ if __name__ == "__main__":
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
nlp.begin_training()
for itn in range(EPOCHS):
print()
print("EPOCH", itn)
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
for itn in range(EPOCHS):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
with nlp.disable_pipes(*other_pipes):
for batch in batches:
docs, golds = zip(*batch)
nlp.update(
@ -151,20 +154,89 @@ if __name__ == "__main__":
drop=DROPOUT,
losses=losses,
)
print("Losses", losses)
# STEP 7: apply the EL algorithm on the dev dataset (TODO: overlaps with code from run_el_training ?)
if apply_to_dev:
run_el.run_el_dev(kb=my_kb, nlp=nlp, training_dir=TRAINING_DIR, limit=2000)
print()
el_pipe.context_weight = 1
el_pipe.prior_weight = 1
dev_acc_1_1 = _measure_accuracy(dev_data, nlp)
train_acc_1_1 = _measure_accuracy(train_data, nlp)
# test KB
el_pipe.context_weight = 0
el_pipe.prior_weight = 1
dev_acc_0_1 = _measure_accuracy(dev_data, nlp)
train_acc_0_1 = _measure_accuracy(train_data, nlp)
el_pipe.context_weight = 1
el_pipe.prior_weight = 0
dev_acc_1_0 = _measure_accuracy(dev_data, nlp)
train_acc_1_0 = _measure_accuracy(train_data, nlp)
print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, losses['entity_linker'],
round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/",
round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2))
# test Entity Linker
if to_test_pipeline:
run_el.run_el_toy_example(kb=my_kb, nlp=nlp)
print()
# TODO coreference resolution
# add_coref()
run_el_toy_example(kb=my_kb, nlp=nlp)
print()
print()
print("STOP", datetime.datetime.now())
def _measure_accuracy(data, nlp):
correct = 0
incorrect = 0
texts = [d.text for d, g in data]
docs = list(nlp.pipe(texts))
golds = [g for d, g in data]
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
for entity in gold.links:
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to other types
pred_entity = ent.kb_id_
start = ent.start
end = ent.end
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
if gold_entity is not None:
if gold_entity == pred_entity:
correct += 1
else:
incorrect += 1
if correct == incorrect == 0:
return 0
acc = correct / (correct + incorrect)
return acc
def run_el_toy_example(nlp, kb):
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel. " \
"The main character in Doug's novel is the man Arthur Dent, " \
"but Douglas doesn't write about George Washington or Homer Simpson."
doc = nlp(text)
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
print()
# Q4426480 is her husband, Q3568763 her tutor
text = "Ada Lovelace loved her husband William King dearly. " \
"Ada Lovelace was tutored by her favorite physics tutor William King."
doc = nlp(text)
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
if __name__ == "__main__":
run_pipeline()

View File

@ -1068,6 +1068,8 @@ class EntityLinker(Pipe):
DOCS: TODO
"""
name = 'entity_linker'
context_weight = 1
prior_weight = 1
@classmethod
def Model(cls, **cfg):
@ -1093,14 +1095,15 @@ class EntityLinker(Pipe):
self.doc_cutoff = self.cfg["doc_cutoff"]
def use_avg_params(self):
"""Modify the pipe's encoders/models, to use their average parameter values."""
with self.article_encoder.use_params(self.sgd_article.averages) \
and self.sent_encoder.use_params(self.sgd_sent.averages) \
and self.mention_encoder.use_params(self.sgd_mention.averages):
yield
# Modify the pipe's encoders/models, to use their average parameter values.
# TODO: this doesn't work yet because there's no exit method
self.article_encoder.use_params(self.sgd_article.averages)
self.sent_encoder.use_params(self.sgd_sent.averages)
self.mention_encoder.use_params(self.sgd_mention.averages)
def require_model(self):
"""Raise an error if the component's model is not initialized."""
# Raise an error if the component's model is not initialized.
if getattr(self, "mention_encoder", None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))
@ -1110,6 +1113,7 @@ class EntityLinker(Pipe):
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
return self.sgd_article
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
self.require_model()
@ -1229,27 +1233,27 @@ class EntityLinker(Pipe):
candidates = self.kb.get_candidates(ent.text)
if candidates:
with self.use_avg_params:
scores = list()
for c in candidates:
prior_prob = c.prior_prob
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = cosine([entity_encoding], mention_enc_t)
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score)
scores = list()
for c in candidates:
prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = cosine(np.asarray([entity_encoding]), mention_enc_t) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score)
# TODO: thresholding
best_index = scores.index(max(scores))
best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate)
# TODO: thresholding
best_index = scores.index(max(scores))
best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_)
return final_entities, final_kb_ids
def set_annotations(self, docs, entities, kb_ids=None):
for entity, kb_id in zip(entities, kb_ids):
entity.ent_kb_id_ = kb_id
for token in entity:
token.ent_kb_id_ = kb_id
class Sentencizer(object):
"""Segment the Doc into sentences using a rule-based strategy.