baseline performances: oracle KB, random and prior prob

svlandeg 2019-06-17 14:39:40 +02:00
parent 24db1392b9
commit 6332af40de
2 changed files with 161 additions and 98 deletions


@@ -5,11 +5,8 @@ import os
import re
import bz2
import datetime
-from os import listdir
-from examples.pipeline.wiki_entity_linking import run_el
from spacy.gold import GoldParse
-from spacy.matcher import PhraseMatcher
from . import wikipedia_processor as wp, kb_creator

"""
@@ -17,7 +14,7 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
"""

# ENTITY_FILE = "gold_entities.csv"
-ENTITY_FILE = "gold_entities_100000.csv"  # use this file for faster processing
+ENTITY_FILE = "gold_entities_1000000.csv"  # use this file for faster processing


def create_training(entity_def_input, training_output):
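For reference: read_training further down parses each row of this gold-entities file as five pipe-separated fields, in the order article_id, alias, wp_title, start, end. A minimal sketch of that layout (the values are invented; the header row, whose first field is the literal string "article_id", is skipped by the reader):

# Hypothetical example row of the gold entities CSV (values invented for illustration):
# article_id | alias | wp_title | start | end
example_row = "12|Douglas Adams|Douglas Adams|21|34"
article_id, alias, wp_title, start, end = example_row.split("|")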
@@ -58,7 +55,6 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
            if cnt % 1000000 == 0:
                print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
            clean_line = line.strip().decode("utf-8")
-            # print(clean_line)

            if clean_line == "<revision>":
                reading_revision = True
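This hunk sits inside the dump-reading loop of _process_wikipedia_texts. A minimal sketch of how such a loop streams the compressed dump, with a hypothetical dump path; the real code additionally tracks revision state and an optional line limit:

import bz2
import datetime

cnt = 0
# stream the XML dump line by line instead of loading it into memory
with bz2.open("enwiki-pages-articles.xml.bz2", mode="rb") as file:
    for line in file:
        if cnt % 1000000 == 0:
            print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
        clean_line = line.strip().decode("utf-8")
        cnt += 1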
@@ -121,7 +117,6 @@ text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')

def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
    found_entities = False
-    # print("Processing", article_id, article_title)

    # ignore meta Wikipedia pages
    if article_title.startswith("Wikipedia:"):
@@ -134,13 +129,8 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
    if text.startswith("#REDIRECT"):
        return

-    # print()
-    # print(text)

    # get the raw text without markup etc, keeping only interwiki links
    clean_text = _get_clean_wp_text(text)
-    # print()
-    # print(clean_text)

    # read the text char by char to get the right offsets of the interwiki links
    final_text = ""
@@ -295,68 +285,62 @@ def is_dev(article_id):
    return article_id.endswith("3")


-def read_training_entities(training_output, dev, limit):
-    entityfile_loc = training_output + "/" + ENTITY_FILE
-    entries_per_article = dict()
-    article_ids = set()
-
-    with open(entityfile_loc, mode='r', encoding='utf8') as file:
-        for line in file:
-            if not limit or len(article_ids) < limit:
-                fields = line.replace('\n', "").split(sep='|')
-                article_id = fields[0]
-                if dev == is_dev(article_id) and article_id != "article_id":
-                    article_ids.add(article_id)
-                    alias = fields[1]
-                    wp_title = fields[2]
-                    start = fields[3]
-                    end = fields[4]
-
-                    entries_by_offset = entries_per_article.get(article_id, dict())
-                    entries_by_offset[start + "-" + end] = (alias, wp_title)
-                    entries_per_article[article_id] = entries_by_offset
-
-    return entries_per_article


def read_training(nlp, training_dir, dev, limit):
    # This method provides training examples that correspond to the entity annotations found by the nlp object
-    print("reading training entities")
-    entries_per_article = read_training_entities(training_output=training_dir, dev=dev, limit=limit)
-    print("done reading training entities")
+    entityfile_loc = training_dir + "/" + ENTITY_FILE
    data = []

-    for article_id, entries_by_offset in entries_per_article.items():
-        file_name = article_id + ".txt"
-        try:
-            # parse the article text
-            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as file:
-                text = file.read()
-                article_doc = nlp(text)
-
-                gold_entities = list()
-                for ent in article_doc.ents:
-                    start = ent.start_char
-                    end = ent.end_char
+    # we assume the data is written sequentially
+    current_article_id = None
+    current_doc = None
+    gold_entities = list()
+    ents_by_offset = dict()
+    skip_articles = set()
+    total_entities = 0

-                    entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
-                    if entity_tuple:
-                        alias, wp_title = entity_tuple
-                        if ent.text != alias:
-                            print("Non-matching entity in", article_id, start, end)
-                        else:
-                            gold_entities.append((start, end, wp_title))
+    with open(entityfile_loc, mode='r', encoding='utf8') as file:
+        for line in file:
+            if not limit or len(data) < limit:
+                if len(data) > 0 and len(data) % 50 == 0:
+                    print("Read", total_entities, "entities in", len(data), "articles")
+                fields = line.replace('\n', "").split(sep='|')
+                article_id = fields[0]
+                alias = fields[1]
+                wp_title = fields[2]
+                start = fields[3]
+                end = fields[4]

-                if gold_entities:
-                    gold = GoldParse(doc=article_doc, links=gold_entities)
-                    data.append((article_doc, gold))
+                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                    if not current_doc or (current_article_id != article_id):
+                        # store the data from the previous article
+                        if gold_entities and current_doc:
+                            gold = GoldParse(doc=current_doc, links=gold_entities)
+                            data.append((current_doc, gold))
+                            total_entities += len(gold_entities)

-        except Exception as e:
-            print("Problem parsing article", article_id)
-            print(e)
-            raise e
+                        # parse the new article text
+                        file_name = article_id + ".txt"
+                        try:
+                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                                text = f.read()
+                                current_doc = nlp(text)
+                                for ent in current_doc.ents:
+                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
+                        except Exception as e:
+                            print("Problem parsing article", article_id, e)
+
+                        current_article_id = article_id
+                        gold_entities = list()
+
+                    # repeat checking this condition in case an exception was thrown
+                    if current_doc and (current_article_id == article_id):
+                        found_ent = ents_by_offset.get(start + "_" + end, None)
+                        if found_ent:
+                            if found_ent != alias:
+                                skip_articles.add(current_article_id)
+                            else:
+                                gold_entities.append((int(start), int(end), wp_title))
+
+    print("Read", total_entities, "entities in", len(data), "articles")
    return data
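The rewritten reader matches gold mentions to recognised entities purely by character offsets. A toy sketch of that convention, with invented values:

# entities found by the nlp object are keyed as "start_end" strings,
# so gold mentions (which carry character offsets) can be looked up directly
ents_by_offset = {}
for ent_text, start_char, end_char in [("Douglas Adams", 21, 34)]:
    ents_by_offset[str(start_char) + "_" + str(end_char)] = ent_text

start, end, alias = "21", "34", "Douglas Adams"
found_ent = ents_by_offset.get(start + "_" + end, None)
if found_ent is not None and found_ent == alias:
    print("gold mention aligns with a recognised entity")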


@@ -64,7 +64,8 @@ def run_pipeline():
    to_test_pipeline = True

    # write the NLP object, read back in and test again
-    test_nlp_io = False
+    to_write_nlp = True
+    to_read_nlp = True

    # STEP 1 : create prior probabilities from WP
    # run only once !
@@ -133,7 +134,7 @@ def run_pipeline():
    if train_pipe:
        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 10
+        train_limit = 5
        dev_limit = 2

        train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -166,46 +167,42 @@ def run_pipeline():
                    )
                    batchnr += 1
                except Exception as e:
-                    print("Error updating batch", e)
+                    print("Error updating batch:", e)
+                    raise(e)

-            losses['entity_linker'] = losses['entity_linker'] / batchnr
-            print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
+            if batchnr > 0:
+                losses['entity_linker'] = losses['entity_linker'] / batchnr
+                print("Epoch, train loss", itn, round(losses['entity_linker'], 2))

        dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                      training_dir=TRAINING_DIR,
                                                      dev=True,
                                                      limit=dev_limit)
-        print("Dev testing on", len(dev_data), "articles")
        print()
+        print("Dev testing on", len(dev_data), "articles")

        if len(dev_data) and measure_performance:
            print()
            print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
            print()

+            acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label = _measure_baselines(dev_data, kb_2)
+            print("dev acc oracle:", round(acc_oracle, 3), [(x, round(y, 3)) for x, y in acc_oracle_by_label.items()])
+            print("dev acc random:", round(acc_random, 3), [(x, round(y, 3)) for x, y in acc_random_by_label.items()])
+            print("dev acc prior:", round(acc_prior, 3), [(x, round(y, 3)) for x, y in acc_prior_by_label.items()])

-            # print(" measuring accuracy 1-1")
            el_pipe.context_weight = 1
            el_pipe.prior_weight = 1
-            dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
-            train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])
-
-            # baseline using only prior probabilities
-            el_pipe.context_weight = 0
-            el_pipe.prior_weight = 1
-            dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
-            train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])
+            dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc combo:", round(dev_acc_combo, 3), [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])

            # using only context
            el_pipe.context_weight = 1
            el_pipe.prior_weight = 0
-            dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
-            train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
+            dev_acc_context, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc context:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])

            print()
            # reset for follow-up tests
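Setting context_weight and prior_weight to 0 or 1 isolates the two signals the entity linker combines. A hedged helper sketch of the pattern used above, assuming el_pipe exposes those two attributes as in this script:

def eval_setting(el_pipe, data, context_weight, prior_weight, label):
    # toggle one signal at a time, then measure accuracy with the shared helper
    el_pipe.context_weight = context_weight
    el_pipe.prior_weight = prior_weight
    acc, acc_by_label = _measure_accuracy(data, el_pipe)
    print(label, round(acc, 3), [(x, round(y, 3)) for x, y in acc_by_label.items()])

# eval_setting(el_pipe, dev_data, 1, 1, "dev acc combo:")    # both signals
# eval_setting(el_pipe, dev_data, 1, 0, "dev acc context:")  # context only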
@@ -219,7 +216,7 @@ def run_pipeline():
        run_el_toy_example(nlp=nlp_2)
        print()

-    if test_nlp_io:
+    if to_write_nlp:
        print()
        print("STEP 9: testing NLP IO", datetime.datetime.now())
        print()
@@ -229,9 +226,10 @@ def run_pipeline():
        print("reading from", NLP_2_DIR)
        nlp_3 = spacy.load(NLP_2_DIR)
-        print()
-        print("running toy example with NLP 2")
-        run_el_toy_example(nlp=nlp_3)
+    if to_read_nlp:
+        print()
+        print("running toy example with NLP 2")
+        run_el_toy_example(nlp=nlp_3)

    print()
    print("STOP", datetime.datetime.now())
@@ -270,6 +268,80 @@ def _measure_accuracy(data, el_pipe):
        except Exception as e:
            print("Error assessing accuracy", e)

    acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
    return acc, acc_by_label


+def _measure_baselines(data, kb):
+    random_correct_by_label = dict()
+    random_incorrect_by_label = dict()
+    oracle_correct_by_label = dict()
+    oracle_incorrect_by_label = dict()
+    prior_correct_by_label = dict()
+    prior_incorrect_by_label = dict()

+    docs = [d for d, g in data if len(d) > 0]
+    golds = [g for d, g in data if len(d) > 0]

+    for doc, gold in zip(docs, golds):
+        try:
+            correct_entries_per_article = dict()
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb

+            for ent in doc.ents:
+                ent_label = ent.label_
+                start = ent.start_char
+                end = ent.end_char
+                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)

+                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+                if gold_entity is not None:
+                    candidates = kb.get_candidates(ent.text)
+                    oracle_candidate = ""
+                    best_candidate = ""
+                    random_candidate = ""
+                    if candidates:
+                        scores = list()
+                        for c in candidates:
+                            scores.append(c.prior_prob)
+                            if c.entity_ == gold_entity:
+                                oracle_candidate = c.entity_

+                        best_index = scores.index(max(scores))
+                        best_candidate = candidates[best_index].entity_
+                        random_candidate = random.choice(candidates).entity_

+                    if gold_entity == best_candidate:
+                        prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1

+                    if gold_entity == random_candidate:
+                        random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1

+                    if gold_entity == oracle_candidate:
+                        oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1

+        except Exception as e:
+            print("Error assessing accuracy", e)

+    acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
+    acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
+    acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)

+    return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label


def calculate_acc(correct_by_label, incorrect_by_label):
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
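Distilled, the three baselines pick a candidate for one gold mention as follows: 'prior' takes the candidate with the highest prior probability, 'random' picks uniformly, and 'oracle' returns the gold id whenever the KB proposed it at all, so it upper-bounds what any disambiguation strategy can reach with this candidate generator. A sketch mirroring the logic above:

import random

def baseline_picks(candidates, gold_entity):
    # prior: highest prior probability; random: uniform choice;
    # oracle: the gold id, if the KB proposed it among the candidates
    best_candidate = ""
    random_candidate = ""
    oracle_candidate = ""
    if candidates:
        scores = [c.prior_prob for c in candidates]
        best_candidate = candidates[scores.index(max(scores))].entity_
        random_candidate = random.choice(candidates).entity_
        if any(c.entity_ == gold_entity for c in candidates):
            oracle_candidate = gold_entity
    return best_candidate, random_candidate, oracle_candidate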
@@ -303,18 +375,25 @@ def run_el_toy_example(nlp):
           "The main character in Doug's novel is the man Arthur Dent, " \
           "but Douglas doesn't write about George Washington or Homer Simpson."
    doc = nlp(text)
    print(text)
    for ent in doc.ents:
        print("ent", ent.text, ent.label_, ent.kb_id_)
    print()

-    # Q4426480 is her husband, Q3568763 her tutor
-    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
-           "Ada Lovelace loved her husband William King dearly. " \
-           "Ada Lovelace was tutored by her favorite physics tutor William King."
+    # Q4426480 is her husband
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
+           "She loved her husband William King dearly. "
    doc = nlp(text)
    print(text)
    for ent in doc.ents:
        print("ent", ent.text, ent.label_, ent.kb_id_)
+    print()

+    # Q3568763 is her tutor
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
+           "She was tutored by her favorite physics tutor William King."
+    doc = nlp(text)
+    print(text)
+    for ent in doc.ents:
+        print("ent", ent.text, ent.label_, ent.kb_id_)