baseline performances: oracle KB, random and prior prob

commit 6332af40de (parent 24db1392b9)
mirror of https://github.com/explosion/spaCy.git
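This commit adds three reference baselines for the entity linker, measured by the new _measure_baselines helper in the diff below: random (pick any candidate the KB proposes for the mention), prior probability (pick the candidate with the highest prior), and oracle (pick the gold entity whenever the KB proposes it at all, an upper bound set by candidate generation). A minimal sketch of the per-mention logic, pulled out of the diff for readability (baseline_predictions is an illustrative name, not part of the commit; candidate objects expose entity_ and prior_prob as in the committed code):

    import random

    def baseline_predictions(candidates, gold_entity):
        # Illustrative sketch of the per-entity logic in _measure_baselines below.
        # An empty string means the baseline picked nothing (counts as incorrect).
        random_candidate = ""
        prior_candidate = ""
        oracle_candidate = ""
        if candidates:
            scores = [c.prior_prob for c in candidates]
            prior_candidate = candidates[scores.index(max(scores))].entity_
            random_candidate = random.choice(candidates).entity_
            for c in candidates:
                if c.entity_ == gold_entity:
                    oracle_candidate = c.entity_
        return random_candidate, prior_candidate, oracle_candidate

Each prediction is scored against the gold KB id and tallied per entity label; calculate_acc then turns the tallies into accuracies.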
@@ -5,11 +5,8 @@ import os
 import re
 import bz2
 import datetime
-from os import listdir

-from examples.pipeline.wiki_entity_linking import run_el
 from spacy.gold import GoldParse
-from spacy.matcher import PhraseMatcher

 from . import wikipedia_processor as wp, kb_creator

 """
@@ -17,7 +14,7 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
 """

 # ENTITY_FILE = "gold_entities.csv"
-ENTITY_FILE = "gold_entities_100000.csv"   # use this file for faster processing
+ENTITY_FILE = "gold_entities_1000000.csv"  # use this file for faster processing


 def create_training(entity_def_input, training_output):
@@ -58,7 +55,6 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
             if cnt % 1000000 == 0:
                 print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
             clean_line = line.strip().decode("utf-8")
-            # print(clean_line)

             if clean_line == "<revision>":
                 reading_revision = True
@@ -121,7 +117,6 @@ text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')

 def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
     found_entities = False
-    # print("Processing", article_id, article_title)

     # ignore meta Wikipedia pages
     if article_title.startswith("Wikipedia:"):
@@ -134,13 +129,8 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
     if text.startswith("#REDIRECT"):
         return

-    # print()
-    # print(text)
-
     # get the raw text without markup etc, keeping only interwiki links
     clean_text = _get_clean_wp_text(text)
-    # print()
-    # print(clean_text)

     # read the text char by char to get the right offsets of the interwiki links
     final_text = ""
@@ -295,68 +285,62 @@ def is_dev(article_id):
     return article_id.endswith("3")


-def read_training_entities(training_output, dev, limit):
-    entityfile_loc = training_output + "/" + ENTITY_FILE
-    entries_per_article = dict()
-    article_ids = set()
-
-    with open(entityfile_loc, mode='r', encoding='utf8') as file:
-        for line in file:
-            if not limit or len(article_ids) < limit:
-                fields = line.replace('\n', "").split(sep='|')
-                article_id = fields[0]
-                if dev == is_dev(article_id) and article_id != "article_id":
-                    article_ids.add(article_id)
-
-                    alias = fields[1]
-                    wp_title = fields[2]
-                    start = fields[3]
-                    end = fields[4]
-
-                    entries_by_offset = entries_per_article.get(article_id, dict())
-                    entries_by_offset[start + "-" + end] = (alias, wp_title)
-
-                    entries_per_article[article_id] = entries_by_offset
-
-    return entries_per_article
-
-
 def read_training(nlp, training_dir, dev, limit):
     # This method provides training examples that correspond to the entity annotations found by the nlp object

-    print("reading training entities")
-    entries_per_article = read_training_entities(training_output=training_dir, dev=dev, limit=limit)
-    print("done reading training entities")
+    entityfile_loc = training_dir + "/" + ENTITY_FILE

     data = []
-    for article_id, entries_by_offset in entries_per_article.items():
-        file_name = article_id + ".txt"
-        try:
-            # parse the article text
-            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as file:
-                text = file.read()
-                article_doc = nlp(text)

-                gold_entities = list()
-                for ent in article_doc.ents:
-                    start = ent.start_char
-                    end = ent.end_char
+    # we assume the data is written sequentially
+    current_article_id = None
+    current_doc = None
+    gold_entities = list()
+    ents_by_offset = dict()
+    skip_articles = set()
+    total_entities = 0

-                    entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
-                    if entity_tuple:
-                        alias, wp_title = entity_tuple
-                        if ent.text != alias:
-                            print("Non-matching entity in", article_id, start, end)
-                        else:
-                            gold_entities.append((start, end, wp_title))
+    with open(entityfile_loc, mode='r', encoding='utf8') as file:
+        for line in file:
+            if not limit or len(data) < limit:
+                if len(data) > 0 and len(data) % 50 == 0:
+                    print("Read", total_entities, "entities in", len(data), "articles")
+                fields = line.replace('\n', "").split(sep='|')
+                article_id = fields[0]
+                alias = fields[1]
+                wp_title = fields[2]
+                start = fields[3]
+                end = fields[4]

-                if gold_entities:
-                    gold = GoldParse(doc=article_doc, links=gold_entities)
-                    data.append((article_doc, gold))
+                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                    if not current_doc or (current_article_id != article_id):
+                        # store the data from the previous article
+                        if gold_entities and current_doc:
+                            gold = GoldParse(doc=current_doc, links=gold_entities)
+                            data.append((current_doc, gold))
+                            total_entities += len(gold_entities)

-        except Exception as e:
-            print("Problem parsing article", article_id)
-            print(e)
-            raise e
+                        # parse the new article text
+                        file_name = article_id + ".txt"
+                        try:
+                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                                text = f.read()
+                                current_doc = nlp(text)
+                                for ent in current_doc.ents:
+                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
+                        except Exception as e:
+                            print("Problem parsing article", article_id, e)
+
+                        current_article_id = article_id
+                        gold_entities = list()
+
+                    # repeat checking this condition in case an exception was thrown
+                    if current_doc and (current_article_id == article_id):
+                        found_ent = ents_by_offset.get(start + "_" + end, None)
+                        if found_ent:
+                            if found_ent != alias:
+                                skip_articles.add(current_article_id)
+                            else:
+                                gold_entities.append((int(start), int(end), wp_title))

-
+    print("Read", total_entities, "entities in", len(data), "articles")
     return data
@@ -64,7 +64,8 @@ def run_pipeline():
     to_test_pipeline = True

     # write the NLP object, read back in and test again
-    test_nlp_io = False
+    to_write_nlp = True
+    to_read_nlp = True

     # STEP 1 : create prior probabilities from WP
     # run only once !
@@ -133,7 +134,7 @@ def run_pipeline():

     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 10
+        train_limit = 5
         dev_limit = 2

         train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -166,46 +167,42 @@ def run_pipeline():
                     )
                     batchnr += 1
                 except Exception as e:
-                    print("Error updating batch", e)
+                    print("Error updating batch:", e)
+                    raise(e)

-            losses['entity_linker'] = losses['entity_linker'] / batchnr
-            print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
+            if batchnr > 0:
+                losses['entity_linker'] = losses['entity_linker'] / batchnr
+                print("Epoch, train loss", itn, round(losses['entity_linker'], 2))

         dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                       training_dir=TRAINING_DIR,
                                                       dev=True,
                                                       limit=dev_limit)
-        print("Dev testing on", len(dev_data), "articles")
         print()
+        print("Dev testing on", len(dev_data), "articles")

         if len(dev_data) and measure_performance:
             print()
             print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
             print()

+            acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label = _measure_baselines(dev_data, kb_2)
+            print("dev acc oracle:", round(acc_oracle, 3), [(x, round(y, 3)) for x, y in acc_oracle_by_label.items()])
+            print("dev acc random:", round(acc_random, 3), [(x, round(y, 3)) for x, y in acc_random_by_label.items()])
+            print("dev acc prior:", round(acc_prior, 3), [(x, round(y, 3)) for x, y in acc_prior_by_label.items()])
+
             # print(" measuring accuracy 1-1")
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 1
-            dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
-            train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])
+            dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc combo:", round(dev_acc_combo, 3), [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])

-            # baseline using only prior probabilities
-            el_pipe.context_weight = 0
-            el_pipe.prior_weight = 1
-            dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
-            train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])

             # using only context
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 0
-            dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
-            train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
+            dev_acc_context, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc context:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
             print()

             # reset for follow-up tests
@@ -219,7 +216,7 @@ def run_pipeline():
         run_el_toy_example(nlp=nlp_2)
         print()

-    if test_nlp_io:
+    if to_write_nlp:
         print()
         print("STEP 9: testing NLP IO", datetime.datetime.now())
         print()
@@ -229,9 +226,10 @@ def run_pipeline():
         print("reading from", NLP_2_DIR)
         nlp_3 = spacy.load(NLP_2_DIR)

-        print()
-        print("running toy example with NLP 2")
-        run_el_toy_example(nlp=nlp_3)
+    if to_read_nlp:
+        print()
+        print("running toy example with NLP 2")
+        run_el_toy_example(nlp=nlp_3)

     print()
     print("STOP", datetime.datetime.now())
@@ -270,6 +268,80 @@ def _measure_accuracy(data, el_pipe):
         except Exception as e:
             print("Error assessing accuracy", e)

+    acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
+    return acc, acc_by_label
+
+
+def _measure_baselines(data, kb):
+    random_correct_by_label = dict()
+    random_incorrect_by_label = dict()
+
+    oracle_correct_by_label = dict()
+    oracle_incorrect_by_label = dict()
+
+    prior_correct_by_label = dict()
+    prior_incorrect_by_label = dict()
+
+    docs = [d for d, g in data if len(d) > 0]
+    golds = [g for d, g in data if len(d) > 0]
+
+    for doc, gold in zip(docs, golds):
+        try:
+            correct_entries_per_article = dict()
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+
+            for ent in doc.ents:
+                ent_label = ent.label_
+                start = ent.start_char
+                end = ent.end_char
+                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+
+                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+                if gold_entity is not None:
+                    candidates = kb.get_candidates(ent.text)
+                    oracle_candidate = ""
+                    best_candidate = ""
+                    random_candidate = ""
+                    if candidates:
+                        scores = list()
+
+                        for c in candidates:
+                            scores.append(c.prior_prob)
+                            if c.entity_ == gold_entity:
+                                oracle_candidate = c.entity_
+
+                        best_index = scores.index(max(scores))
+                        best_candidate = candidates[best_index].entity_
+                        random_candidate = random.choice(candidates).entity_
+
+                    if gold_entity == best_candidate:
+                        prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
+
+                    if gold_entity == random_candidate:
+                        random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
+
+                    if gold_entity == oracle_candidate:
+                        oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
+
+        except Exception as e:
+            print("Error assessing accuracy", e)
+
+    acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
+    acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
+    acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
+
+    return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
+
+
+def calculate_acc(correct_by_label, incorrect_by_label):
     acc_by_label = dict()
     total_correct = 0
     total_incorrect = 0
@@ -303,18 +375,25 @@ def run_el_toy_example(nlp):
            "The main character in Doug's novel is the man Arthur Dent, " \
            "but Douglas doesn't write about George Washington or Homer Simpson."
     doc = nlp(text)
+    print(text)
     for ent in doc.ents:
         print("ent", ent.text, ent.label_, ent.kb_id_)

     print()

-    # Q4426480 is her husband, Q3568763 her tutor
-    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
-           "Ada Lovelace loved her husband William King dearly. " \
-           "Ada Lovelace was tutored by her favorite physics tutor William King."
+    # Q4426480 is her husband
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
+           "She loved her husband William King dearly. "
     doc = nlp(text)
+    print(text)
+    for ent in doc.ents:
+        print("ent", ent.text, ent.label_, ent.kb_id_)
+    print()

+    # Q3568763 is her tutor
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
+           "She was tutored by her favorite physics tutor William King."
+    doc = nlp(text)
+    print(text)
     for ent in doc.ents:
         print("ent", ent.text, ent.label_, ent.kb_id_)
-
-
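The hunk above ends inside calculate_acc, so its body is not part of this diff. Judging from the call sites (acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)), it computes per-label and overall accuracy as correct / (correct + incorrect); the sketch below is that assumption, not the committed code:

    def calculate_acc(correct_by_label, incorrect_by_label):
        # Assumed body, reconstructed from the call sites and the three lines
        # of context above: per-label accuracy plus overall accuracy.
        acc_by_label = dict()
        total_correct = 0
        total_incorrect = 0
        for label in set(correct_by_label.keys()) | set(incorrect_by_label.keys()):
            correct = correct_by_label.get(label, 0)
            incorrect = incorrect_by_label.get(label, 0)
            total_correct += correct
            total_incorrect += incorrect
            acc_by_label[label] = correct / (correct + incorrect)
        seen = total_correct + total_incorrect
        acc = total_correct / seen if seen else 0.0
        return acc, acc_by_label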