performance per entity type

This commit is contained in:
svlandeg 2019-06-14 19:55:46 +02:00
parent b312f2d0e7
commit 81731907ba
5 changed files with 114 additions and 80 deletions

View File

@ -15,10 +15,10 @@ INPUT_DIM = 300 # dimension of pre-trained vectors
DESC_WIDTH = 64 DESC_WIDTH = 64
def create_kb(nlp, max_entities_per_alias, min_occ, def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
entity_def_output, entity_descr_output, entity_def_output, entity_descr_output,
count_input, prior_prob_input, to_print=False): count_input, prior_prob_input, to_print=False):
""" Create the knowledge base from Wikidata entries """ # Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH) kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
# disable this part of the pipeline when rerunning the KB generation from preprocessed files # disable this part of the pipeline when rerunning the KB generation from preprocessed files
@ -37,21 +37,26 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
title_to_id = _get_entity_to_id(entity_def_output) title_to_id = _get_entity_to_id(entity_def_output)
id_to_descr = _get_id_to_description(entity_descr_output) id_to_descr = _get_id_to_description(entity_descr_output)
title_list = list(title_to_id.keys())
# TODO: remove this filter (just for quicker testing of code)
# title_list = title_list[0:342]
# title_to_id = {t: title_to_id[t] for t in title_list}
entity_list = [title_to_id[x] for x in title_list]
# Currently keeping entities from the KB where there is no description - putting a default void description
description_list = [id_to_descr.get(x, "No description defined") for x in entity_list]
print() print()
print(" * _get_entity_frequencies", datetime.datetime.now()) print(" * _get_entity_frequencies", datetime.datetime.now())
print() print()
entity_frequencies = wp.get_entity_frequencies(count_input=count_input, entities=title_list) entity_frequencies = wp.get_all_frequencies(count_input=count_input)
# filter the entities in the KB by frequency, because there's just too much data otherwise
filtered_title_to_id = dict()
entity_list = list()
description_list = list()
frequency_list = list()
for title, entity in title_to_id.items():
freq = entity_frequencies.get(title, 0)
desc = id_to_descr.get(entity, None)
if desc and freq > min_entity_freq:
entity_list.append(entity)
description_list.append(desc)
frequency_list.append(freq)
filtered_title_to_id[title] = entity
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles")
print() print()
print(" * train entity encoder", datetime.datetime.now()) print(" * train entity encoder", datetime.datetime.now())
@ -67,12 +72,12 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
print() print()
print(" * adding", len(entity_list), "entities", datetime.datetime.now()) print(" * adding", len(entity_list), "entities", datetime.datetime.now())
kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=embeddings) kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
print() print()
print(" * adding aliases", datetime.datetime.now()) print(" * adding aliases", datetime.datetime.now())
print() print()
_add_aliases(kb, title_to_id=title_to_id, _add_aliases(kb, title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ, max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
prior_prob_input=prior_prob_input) prior_prob_input=prior_prob_input)

View File

@ -21,7 +21,7 @@ ENTITY_FILE = "gold_entities.csv"
def create_training(entity_def_input, training_output): def create_training(entity_def_input, training_output):
wp_to_id = kb_creator._get_entity_to_id(entity_def_input) wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wp_to_id, training_output, limit=100000000) # TODO: full dataset 100000000 _process_wikipedia_texts(wp_to_id, training_output, limit=100000000)
def _process_wikipedia_texts(wp_to_id, training_output, limit=None): def _process_wikipedia_texts(wp_to_id, training_output, limit=None):

View File

@ -29,6 +29,7 @@ NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/' TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES = 10 MAX_CANDIDATES = 10
MIN_ENTITY_FREQ = 200
MIN_PAIR_OCC = 5 MIN_PAIR_OCC = 5
DOC_SENT_CUTOFF = 2 DOC_SENT_CUTOFF = 2
EPOCHS = 10 EPOCHS = 10
@ -46,14 +47,14 @@ def run_pipeline():
# one-time methods to create KB and write to file # one-time methods to create KB and write to file
to_create_prior_probs = False to_create_prior_probs = False
to_create_entity_counts = False to_create_entity_counts = False
to_create_kb = False # TODO: entity_defs should also contain entities not in the KB to_create_kb = True
# read KB back in from file # read KB back in from file
to_read_kb = False to_read_kb = False
to_test_kb = False to_test_kb = False
# create training dataset # create training dataset
create_wp_training = True create_wp_training = False
# train the EL pipe # train the EL pipe
train_pipe = False train_pipe = False
@ -84,13 +85,14 @@ def run_pipeline():
if to_create_kb: if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now()) print("STEP 3a: to_create_kb", datetime.datetime.now())
kb_1 = kb_creator.create_kb(nlp_1, kb_1 = kb_creator.create_kb(nlp_1,
max_entities_per_alias=MAX_CANDIDATES, max_entities_per_alias=MAX_CANDIDATES,
min_occ=MIN_PAIR_OCC, min_entity_freq=MIN_ENTITY_FREQ,
entity_def_output=ENTITY_DEFS, min_occ=MIN_PAIR_OCC,
entity_descr_output=ENTITY_DESCR, entity_def_output=ENTITY_DEFS,
count_input=ENTITY_COUNTS, entity_descr_output=ENTITY_DESCR,
prior_prob_input=PRIOR_PROB, count_input=ENTITY_COUNTS,
to_print=False) prior_prob_input=PRIOR_PROB,
to_print=False)
print("kb entities:", kb_1.get_size_entities()) print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases()) print("kb aliases:", kb_1.get_size_aliases())
print() print()
@ -112,7 +114,7 @@ def run_pipeline():
# test KB # test KB
if to_test_kb: if to_test_kb:
run_el.run_kb_toy_example(kb=kb_2) test_kb(kb_2)
print() print()
# STEP 5: create a training dataset from WP # STEP 5: create a training dataset from WP
@ -121,10 +123,18 @@ def run_pipeline():
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR) training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
# STEP 6: create the entity linking pipe # STEP 6: create the entity linking pipe
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
nlp_2.begin_training()
if train_pipe: if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 50 train_limit = 10
dev_limit = 10 dev_limit = 2
print("Training on", train_limit, "articles") print("Training on", train_limit, "articles")
print("Dev testing on", dev_limit, "articles") print("Dev testing on", dev_limit, "articles")
print() print()
@ -141,14 +151,6 @@ def run_pipeline():
limit=dev_limit, limit=dev_limit,
to_print=False) to_print=False)
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
nlp_2.begin_training()
for itn in range(EPOCHS): for itn in range(EPOCHS):
random.shuffle(train_data) random.shuffle(train_data)
losses = {} losses = {}
@ -180,30 +182,32 @@ def run_pipeline():
# print(" measuring accuracy 1-1") # print(" measuring accuracy 1-1")
el_pipe.context_weight = 1 el_pipe.context_weight = 1
el_pipe.prior_weight = 1 el_pipe.prior_weight = 1
dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe) dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
train_acc_1_1 = _measure_accuracy(train_data, el_pipe) print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
print("train/dev acc combo:", round(train_acc_1_1, 2), round(dev_acc_1_1, 2)) train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])
# baseline using only prior probabilities # baseline using only prior probabilities
el_pipe.context_weight = 0 el_pipe.context_weight = 0
el_pipe.prior_weight = 1 el_pipe.prior_weight = 1
dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe) dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
train_acc_0_1 = _measure_accuracy(train_data, el_pipe) print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
print("train/dev acc prior:", round(train_acc_0_1, 2), round(dev_acc_0_1, 2)) train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])
# using only context # using only context
el_pipe.context_weight = 1 el_pipe.context_weight = 1
el_pipe.prior_weight = 0 el_pipe.prior_weight = 0
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe) dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
train_acc_1_0 = _measure_accuracy(train_data, el_pipe) print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2)) train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
print() print()
# reset for follow-up tests # reset for follow-up tests
el_pipe.context_weight = 1 el_pipe.context_weight = 1
el_pipe.prior_weight = 1 el_pipe.prior_weight = 1
if to_test_pipeline: if to_test_pipeline:
print() print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
@ -230,8 +234,8 @@ def run_pipeline():
def _measure_accuracy(data, el_pipe): def _measure_accuracy(data, el_pipe):
correct = 0 correct_by_label = dict()
incorrect = 0 incorrect_by_label = dict()
docs = [d for d, g in data if len(d) > 0] docs = [d for d, g in data if len(d) > 0]
docs = el_pipe.pipe(docs) docs = el_pipe.pipe(docs)
@ -245,31 +249,53 @@ def _measure_accuracy(data, el_pipe):
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents: for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to other types ent_label = ent.label_
pred_entity = ent.kb_id_ pred_entity = ent.kb_id_
start = ent.start_char start = ent.start_char
end = ent.end_char end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
if gold_entity == pred_entity: if gold_entity == pred_entity:
correct += 1 correct = correct_by_label.get(ent_label, 0)
else: correct_by_label[ent_label] = correct + 1
incorrect += 1 else:
incorrect = incorrect_by_label.get(ent_label, 0)
incorrect_by_label[ent_label] = incorrect + 1
except Exception as e: except Exception as e:
print("Error assessing accuracy", e) print("Error assessing accuracy", e)
if correct == incorrect == 0: acc_by_label = dict()
return 0 total_correct = 0
total_incorrect = 0
for label, correct in correct_by_label.items():
incorrect = incorrect_by_label.get(label, 0)
total_correct += correct
total_incorrect += incorrect
if correct == incorrect == 0:
acc_by_label[label] = 0
else:
acc_by_label[label] = correct / (correct + incorrect)
acc = 0
if not (total_correct == total_incorrect == 0):
acc = total_correct / (total_correct + total_incorrect)
return acc, acc_by_label
acc = correct / (correct + incorrect)
return acc def test_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention)
print("generating candidates for " + mention + " :")
for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
print()
def run_el_toy_example(nlp): def run_el_toy_example(nlp):
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \ text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel. " \ "Douglas reminds us to always bring our towel, even in China or Brazil. " \
"The main character in Doug's novel is the man Arthur Dent, " \ "The main character in Doug's novel is the man Arthur Dent, " \
"but Douglas doesn't write about George Washington or Homer Simpson." "but Douglas doesn't write about George Washington or Homer Simpson."
doc = nlp(text) doc = nlp(text)

View File

@ -1,7 +1,6 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import re
import bz2 import bz2
import json import json
import datetime import datetime
@ -14,7 +13,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """ """ Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
lang = 'en' lang = 'en'
prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected # prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
site_filter = 'enwiki' site_filter = 'enwiki'
title_to_id = dict() title_to_id = dict()
@ -41,18 +40,19 @@ def read_wikidata_entities_json(limit=None, to_print=False):
entry_type = obj["type"] entry_type = obj["type"]
if entry_type == "item": if entry_type == "item":
# filtering records on their properties # filtering records on their properties (currently disabled to get ALL data)
keep = False # keep = False
keep = True
claims = obj["claims"] claims = obj["claims"]
for prop, value_set in prop_filter.items(): # for prop, value_set in prop_filter.items():
claim_property = claims.get(prop, None) # claim_property = claims.get(prop, None)
if claim_property: # if claim_property:
for cp in claim_property: # for cp in claim_property:
cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id') # cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
cp_rank = cp['rank'] # cp_rank = cp['rank']
if cp_rank != "deprecated" and cp_id in value_set: # if cp_rank != "deprecated" and cp_id in value_set:
keep = True # keep = True
if keep: if keep:
unique_id = obj["id"] unique_id = obj["id"]
@ -70,6 +70,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
if to_print: if to_print:
print("prop:", prop, cp_values) print("prop:", prop, cp_values)
found_link = False
if parse_sitelinks: if parse_sitelinks:
site_value = obj["sitelinks"].get(site_filter, None) site_value = obj["sitelinks"].get(site_filter, None)
if site_value: if site_value:
@ -77,6 +78,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
if to_print: if to_print:
print(site_filter, ":", site) print(site_filter, ":", site)
title_to_id[site] = unique_id title_to_id[site] = unique_id
found_link = True
if parse_labels: if parse_labels:
labels = obj["labels"] labels = obj["labels"]
@ -86,7 +88,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
if to_print: if to_print:
print("label (" + lang + "):", lang_label["value"]) print("label (" + lang + "):", lang_label["value"])
if parse_descriptions: if found_link and parse_descriptions:
descriptions = obj["descriptions"] descriptions = obj["descriptions"]
if descriptions: if descriptions:
lang_descr = descriptions.get(lang, None) lang_descr = descriptions.get(lang, None)

View File

@ -175,7 +175,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
print("Total count:", total_count) print("Total count:", total_count)
def get_entity_frequencies(count_input, entities): def get_all_frequencies(count_input):
entity_to_count = dict() entity_to_count = dict()
with open(count_input, 'r', encoding='utf8') as csvfile: with open(count_input, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|') csvreader = csv.reader(csvfile, delimiter='|')
@ -184,4 +184,5 @@ def get_entity_frequencies(count_input, entities):
for row in csvreader: for row in csvreader:
entity_to_count[row[0]] = int(row[1]) entity_to_count[row[0]] = int(row[1])
return [entity_to_count.get(e, 0) for e in entities] return entity_to_count