filter training data beforehand (+black formatting)

This commit is contained in:
svlandeg 2019-07-18 10:22:24 +02:00
parent d833d4c358
commit ec55d2fccd
5 changed files with 294 additions and 160 deletions

View File

@ -18,6 +18,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities.csv"
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
@ -54,12 +58,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(
datetime.datetime.now(),
"processed",
cnt,
"lines of Wikipedia dump",
)
print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
@ -328,8 +327,9 @@ def is_dev(article_id):
def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided, it will include also negative training examples by using the candidate generator.
When kb=None, it will only include positive training examples."""
When kb is provided (for training), it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found in the KB.
When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
@ -402,12 +402,11 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
gold_end = int(end) - found_ent.sent.start_char
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
gold_entities = {}
candidates = kb.get_candidates(alias)
candidate_ids = [c.entity_ for c in candidates]
# add positive example in case the KB doesn't have it
candidate_ids.append(wd_id)
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id)
@ -415,6 +414,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
gold_entities[entry] = 0.0
else:
gold_entities[entry] = 1.0
# keep all positive examples
else:
entry = (gold_start, gold_end, wd_id)
gold_entities = {entry: 1.0}

View File

@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
# these will/should be matched ignoring case
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
"d", "dbdump", "download", "Draft", "Education", "Foundation",
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
"MediaZilla", "Meta", "Metawikipedia", "Module",
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
"tswiki", "User", "User talk", "v", "voy",
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
wiki_namespaces = [
"b",
"betawikiversity",
"Book",
"c",
"Category",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"gerrit",
"File",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"outreach",
"outreachwiki",
"otrs",
"OTRSwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
# find the links
link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":"
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
"""
with bz2.open(wikipedia_input, mode='rb') as file:
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
while line:
if cnt % 5000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
cnt += 1
# write all aliases and their entities and count occurrences to file
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
for entity, count in s_dict:
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
entity_to_count = dict()
total_count = 0
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
while line:
splits = line.replace('\n', "").split(sep='|')
splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
line = prior_file.readline()
with open(count_output, mode='w', encoding='utf8') as entity_file:
with open(count_output, mode="w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
def get_all_frequencies(count_input):
entity_to_count = dict()
with open(count_input, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
with open(count_input, "r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
return entity_to_count

View File

@ -5,7 +5,8 @@ import random
import datetime
from pathlib import Path
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import training_set_creator, kb_creator
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy
@ -17,23 +18,25 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
"""
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
OUTPUT_DIR = ROOT_DIR / "wikipedia"
TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
ENWIKI_DUMP = (
ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
)
# KB construction parameters
MAX_CANDIDATES = 10
@ -48,11 +51,15 @@ L2 = 1e-6
CONTEXT_WIDTH = 128
def now():
return datetime.datetime.now()
def run_pipeline():
# set the appropriate booleans to define which parts of the pipeline should be re(run)
print("START", datetime.datetime.now())
print("START", now())
print()
nlp_1 = spacy.load('en_core_web_lg')
nlp_1 = spacy.load("en_core_web_lg")
nlp_2 = None
kb_2 = None
@ -82,20 +89,21 @@ def run_pipeline():
# STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs:
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
print("STEP 1: to_create_prior_probs", now())
wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
print()
# STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts:
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
print("STEP 2: to_create_entity_counts", now())
wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
print()
# STEP 3 : create KB and write to file (run only once)
if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now())
kb_1 = kb_creator.create_kb(nlp_1,
print("STEP 3a: to_create_kb", now())
kb_1 = kb_creator.create_kb(
nlp=nlp_1,
max_entities_per_alias=MAX_CANDIDATES,
min_entity_freq=MIN_ENTITY_FREQ,
min_occ=MIN_PAIR_OCC,
@ -103,19 +111,20 @@ def run_pipeline():
entity_descr_output=ENTITY_DESCR,
count_input=ENTITY_COUNTS,
prior_prob_input=PRIOR_PROB,
wikidata_input=WIKIDATA_JSON)
wikidata_input=WIKIDATA_JSON,
)
print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases())
print()
print("STEP 3b: write KB and NLP", datetime.datetime.now())
print("STEP 3b: write KB and NLP", now())
kb_1.dump(KB_FILE)
nlp_1.to_disk(NLP_1_DIR)
print()
# STEP 4 : read KB back in from file
if to_read_kb:
print("STEP 4: to_read_kb", datetime.datetime.now())
print("STEP 4: to_read_kb", now())
nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE)
@ -130,20 +139,26 @@ def run_pipeline():
# STEP 5: create a training dataset from WP
if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now())
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
print("STEP 5: create training dataset", now())
training_set_creator.create_training(
wikipedia_input=ENWIKI_DUMP,
entity_def_input=ENTITY_DEFS,
training_output=TRAINING_DIR)
training_output=TRAINING_DIR,
)
# STEP 6: create and train the entity linking pipe
if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
print("STEP 6: training Entity Linking pipe", now())
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
print(" -analysing", len(type_to_int), "different entity types")
el_pipe = nlp_2.create_pipe(name='entity_linker',
config={"context_width": CONTEXT_WIDTH,
el_pipe = nlp_2.create_pipe(
name="entity_linker",
config={
"context_width": CONTEXT_WIDTH,
"pretrained_vectors": nlp_2.vocab.vectors.name,
"type_to_int": type_to_int})
"type_to_int": type_to_int,
},
)
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
@ -157,18 +172,22 @@ def run_pipeline():
train_limit = 5000
dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2,
# for training, get pos & neg instances that correspond to entries in the kb
train_data = training_set_creator.read_training(
nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=False,
limit=train_limit)
limit=train_limit,
kb=el_pipe.kb,
)
print("Training on", len(train_data), "articles")
print()
dev_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
)
print("Dev testing on", len(dev_data), "articles")
print()
@ -187,8 +206,8 @@ def run_pipeline():
try:
docs, golds = zip(*batch)
nlp_2.update(
docs,
golds,
docs=docs,
golds=golds,
sgd=optimizer,
drop=DROPOUT,
losses=losses,
@ -200,48 +219,61 @@ def run_pipeline():
if batchnr > 0:
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
losses['entity_linker'] = losses['entity_linker'] / batchnr
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
" / dev acc avg", round(dev_acc_context, 3))
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses["entity_linker"] = losses["entity_linker"] / batchnr
print(
"Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev acc avg",
round(dev_acc_context, 3),
)
# STEP 7: measure the performance of our trained pipe on an independent dev set
if len(dev_data) and measure_performance:
print()
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
print("STEP 7: performance measurement of Entity Linking pipe", now())
print()
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb_2
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev acc random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev acc prior:", round(acc_p, 3), prior_by_label)
# using only context
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 0
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
print("dev acc context avg:", round(dev_acc_context, 3),
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context)
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
print("dev acc combo avg:", round(dev_acc_combo, 3),
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
# STEP 8: apply the EL pipe on a toy example
if to_test_pipeline:
print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
print("STEP 8: applying Entity Linking to toy example", now())
print()
run_el_toy_example(nlp=nlp_2)
# STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp:
print()
print("STEP 9: testing NLP IO", datetime.datetime.now())
print("STEP 9: testing NLP IO", now())
print()
print("writing to", NLP_2_DIR)
nlp_2.to_disk(NLP_2_DIR)
@ -262,23 +294,26 @@ def run_pipeline():
el_pipe = nlp_3.get_pipe("entity_linker")
dev_limit = 5000
dev_data = training_set_creator.read_training(nlp=nlp_2,
dev_data = training_set_creator.read_training(
nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit)
limit=dev_limit,
kb=el_pipe.kb,
)
print("Dev testing from file on", len(dev_data), "articles")
print()
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
print("dev acc combo avg:", round(dev_acc_combo, 3),
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
print()
print("STOP", datetime.datetime.now())
print("STOP", now())
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict()
incorrect_by_label = dict()
@ -291,7 +326,9 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity in gold.links:
for entity, value in gold.links.items():
# only evaluating on positive examples
if value:
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
@ -300,7 +337,8 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
offset = str(start) + "-" + str(end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
@ -311,7 +349,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
incorrect_by_label[ent_label] = incorrect + 1
if error_analysis:
print(ent.text, "in", doc)
print("Predicted", pred_entity, "should have been", gold_entity)
print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print()
except Exception as e:
@ -323,16 +366,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_by_label = dict()
counts_d = dict()
random_correct_by_label = dict()
random_incorrect_by_label = dict()
random_correct_d = dict()
random_incorrect_d = dict()
oracle_correct_by_label = dict()
oracle_incorrect_by_label = dict()
oracle_correct_d = dict()
oracle_incorrect_d = dict()
prior_correct_by_label = dict()
prior_incorrect_by_label = dict()
prior_correct_d = dict()
prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
@ -345,14 +388,15 @@ def _measure_baselines(data, kb):
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
label = ent.label_
start = ent.start_char
end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
offset = str(start) + "-" + str(end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
@ -370,28 +414,36 @@ def _measure_baselines(data, kb):
random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate:
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else:
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate:
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
random_correct_d[label] = random_correct_d.get(label, 0) + 1
else:
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate:
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else:
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e:
print("Error assessing accuracy", e)
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
)
def calculate_acc(correct_by_label, incorrect_by_label):
@ -422,15 +474,23 @@ def check_kb(kb):
print("generating candidates for " + mention + " :")
for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
print(
" ",
c.prior_prob,
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print()
def run_el_toy_example(nlp):
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
"The main character in Doug's novel is the man Arthur Dent, " \
text = (
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, "
"but Douglas doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
print(text)
for ent in doc.ents:

View File

@ -208,7 +208,7 @@ cdef class KnowledgeBase:
# Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index:
return []
return [0] * self.entity_vector_length
entry_index = self._entry_index[entity_hash]
return self._vectors_table[self._entries[entry_index].vector_index]

View File

@ -1151,9 +1151,11 @@ class EntityLinker(Pipe):
ents_by_offset = dict()
for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
for entity in gold.links:
start, end, gold_kb = entity
for entity, value in gold.links.items():
start, end, kb_id = entity
mention = doc.text[start:end]
entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention)
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
assert gold_ent is not None
@ -1161,24 +1163,17 @@ class EntityLinker(Pipe):
if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1
candidates = self.kb.get_candidates(mention)
random.shuffle(candidates)
for c in candidates:
kb_id = c.entity_
entity_encoding = c.entity_vector
# store data
entity_encodings.append(entity_encoding)
context_docs.append(doc)
type_vectors.append(type_vector)
if self.cfg.get("prior_weight", 1) > 0:
priors.append([c.prior_prob])
priors.append([prior_prob])
else:
priors.append([0])
if kb_id == gold_kb:
cats.append([1])
else:
cats.append([0])
cats.append([value])
if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)