filter training data beforehand (+black formatting)

This commit is contained in:
svlandeg 2019-07-18 10:22:24 +02:00
parent d833d4c358
commit ec55d2fccd
5 changed files with 294 additions and 160 deletions

View File

@ -18,6 +18,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities.csv" ENTITY_FILE = "gold_entities.csv"
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output): def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input) wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None) _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
@ -54,12 +58,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
reading_revision = False reading_revision = False
while line and (not limit or cnt < limit): while line and (not limit or cnt < limit):
if cnt % 1000000 == 0: if cnt % 1000000 == 0:
print( print(now(), "processed", cnt, "lines of Wikipedia dump")
datetime.datetime.now(),
"processed",
cnt,
"lines of Wikipedia dump",
)
clean_line = line.strip().decode("utf-8") clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>": if clean_line == "<revision>":
@ -328,8 +327,9 @@ def is_dev(article_id):
def read_training(nlp, training_dir, dev, limit, kb=None): def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object. """ This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided, it will include also negative training examples by using the candidate generator. When kb is provided (for training), it will include negative training examples by using the candidate generator,
When kb=None, it will only include positive training examples.""" and it will only keep positive training examples that can be found in the KB.
When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE entityfile_loc = training_dir / ENTITY_FILE
data = [] data = []
@ -402,12 +402,11 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
gold_end = int(end) - found_ent.sent.start_char gold_end = int(end) - found_ent.sent.start_char
# add both pos and neg examples (in random order) # add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb: if kb:
gold_entities = {} gold_entities = {}
candidates = kb.get_candidates(alias) candidates = kb.get_candidates(alias)
candidate_ids = [c.entity_ for c in candidates] candidate_ids = [c.entity_ for c in candidates]
# add positive example in case the KB doesn't have it
candidate_ids.append(wd_id)
random.shuffle(candidate_ids) random.shuffle(candidate_ids)
for kb_id in candidate_ids: for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id) entry = (gold_start, gold_end, kb_id)
@ -415,6 +414,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
gold_entities[entry] = 0.0 gold_entities[entry] = 0.0
else: else:
gold_entities[entry] = 1.0 gold_entities[entry] = 1.0
# keep all positive examples
else: else:
entry = (gold_start, gold_end, wd_id) entry = (gold_start, gold_end, wd_id)
gold_entities = {entry: 1.0} gold_entities = {entry: 1.0}

View File

@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict() map_alias_to_link = dict()
# these will/should be matched ignoring case # these will/should be matched ignoring case
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons", wiki_namespaces = [
"d", "dbdump", "download", "Draft", "Education", "Foundation", "b",
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator", "betawikiversity",
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki", "Book",
"MediaZilla", "Meta", "Metawikipedia", "Module", "c",
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki", "Category",
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev", "Commons",
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn", "d",
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools", "dbdump",
"tswiki", "User", "User talk", "v", "voy", "download",
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews", "Draft",
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech", "Education",
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"] "Foundation",
"Gadget",
"Gadget definition",
"gerrit",
"File",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"outreach",
"outreachwiki",
"otrs",
"OTRSwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
# find the links # find the links
link_regex = re.compile(r'\[\[[^\[\]]*\]\]') link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:` # match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":" ns_regex = r":?" + "[a-z][a-z]" + ":"
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE) ns_regex = re.compile(ns_regex, re.IGNORECASE)
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output): def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output):
""" """
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines. The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from. It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
""" """
with bz2.open(wikipedia_input, mode='rb') as file: with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline() line = file.readline()
cnt = 0 cnt = 0
while line: while line:
if cnt % 5000000 == 0: if cnt % 5000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8") clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line) aliases, entities, normalizations = get_wp_links(clean_line)
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
cnt += 1 cnt += 1
# write all aliases and their entities and count occurrences to file # write all aliases and their entities and count occurrences to file
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile: with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n") outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]): for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True): s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
for entity, count in s_dict:
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n") outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
entity_to_count = dict() entity_to_count = dict()
total_count = 0 total_count = 0
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
# skip header # skip header
prior_file.readline() prior_file.readline()
line = prior_file.readline() line = prior_file.readline()
while line: while line:
splits = line.replace('\n', "").split(sep='|') splits = line.replace("\n", "").split(sep="|")
# alias = splits[0] # alias = splits[0]
count = int(splits[1]) count = int(splits[1])
entity = splits[2] entity = splits[2]
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
line = prior_file.readline() line = prior_file.readline()
with open(count_output, mode='w', encoding='utf8') as entity_file: with open(count_output, mode="w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n") entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items(): for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n") entity_file.write(entity + "|" + str(count) + "\n")
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
def get_all_frequencies(count_input): def get_all_frequencies(count_input):
entity_to_count = dict() entity_to_count = dict()
with open(count_input, 'r', encoding='utf8') as csvfile: with open(count_input, "r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter='|') csvreader = csv.reader(csvfile, delimiter="|")
# skip header # skip header
next(csvreader) next(csvreader)
for row in csvreader: for row in csvreader:
entity_to_count[row[0]] = int(row[1]) entity_to_count[row[0]] = int(row[1])
return entity_to_count return entity_to_count

View File

@ -5,7 +5,8 @@ import random
import datetime import datetime
from pathlib import Path from pathlib import Path
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import training_set_creator, kb_creator
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy import spacy
@ -17,23 +18,25 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
""" """
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/") ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / 'wikipedia' OUTPUT_DIR = ROOT_DIR / "wikipedia"
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel' TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv' PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv' ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv' ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv' ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb' KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
NLP_1_DIR = OUTPUT_DIR / 'nlp_1' NLP_1_DIR = OUTPUT_DIR / "nlp_1"
NLP_2_DIR = OUTPUT_DIR / 'nlp_2' NLP_2_DIR = OUTPUT_DIR / "nlp_2"
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2' WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/ # get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2' ENWIKI_DUMP = (
ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
)
# KB construction parameters # KB construction parameters
MAX_CANDIDATES = 10 MAX_CANDIDATES = 10
@ -48,11 +51,15 @@ L2 = 1e-6
CONTEXT_WIDTH = 128 CONTEXT_WIDTH = 128
def now():
return datetime.datetime.now()
def run_pipeline(): def run_pipeline():
# set the appropriate booleans to define which parts of the pipeline should be re(run) # set the appropriate booleans to define which parts of the pipeline should be re(run)
print("START", datetime.datetime.now()) print("START", now())
print() print()
nlp_1 = spacy.load('en_core_web_lg') nlp_1 = spacy.load("en_core_web_lg")
nlp_2 = None nlp_2 = None
kb_2 = None kb_2 = None
@ -82,20 +89,21 @@ def run_pipeline():
# STEP 1 : create prior probabilities from WP (run only once) # STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs: if to_create_prior_probs:
print("STEP 1: to_create_prior_probs", datetime.datetime.now()) print("STEP 1: to_create_prior_probs", now())
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB) wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
print() print()
# STEP 2 : deduce entity frequencies from WP (run only once) # STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts: if to_create_entity_counts:
print("STEP 2: to_create_entity_counts", datetime.datetime.now()) print("STEP 2: to_create_entity_counts", now())
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False) wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
print() print()
# STEP 3 : create KB and write to file (run only once) # STEP 3 : create KB and write to file (run only once)
if to_create_kb: if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now()) print("STEP 3a: to_create_kb", now())
kb_1 = kb_creator.create_kb(nlp_1, kb_1 = kb_creator.create_kb(
nlp=nlp_1,
max_entities_per_alias=MAX_CANDIDATES, max_entities_per_alias=MAX_CANDIDATES,
min_entity_freq=MIN_ENTITY_FREQ, min_entity_freq=MIN_ENTITY_FREQ,
min_occ=MIN_PAIR_OCC, min_occ=MIN_PAIR_OCC,
@ -103,19 +111,20 @@ def run_pipeline():
entity_descr_output=ENTITY_DESCR, entity_descr_output=ENTITY_DESCR,
count_input=ENTITY_COUNTS, count_input=ENTITY_COUNTS,
prior_prob_input=PRIOR_PROB, prior_prob_input=PRIOR_PROB,
wikidata_input=WIKIDATA_JSON) wikidata_input=WIKIDATA_JSON,
)
print("kb entities:", kb_1.get_size_entities()) print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases()) print("kb aliases:", kb_1.get_size_aliases())
print() print()
print("STEP 3b: write KB and NLP", datetime.datetime.now()) print("STEP 3b: write KB and NLP", now())
kb_1.dump(KB_FILE) kb_1.dump(KB_FILE)
nlp_1.to_disk(NLP_1_DIR) nlp_1.to_disk(NLP_1_DIR)
print() print()
# STEP 4 : read KB back in from file # STEP 4 : read KB back in from file
if to_read_kb: if to_read_kb:
print("STEP 4: to_read_kb", datetime.datetime.now()) print("STEP 4: to_read_kb", now())
nlp_2 = spacy.load(NLP_1_DIR) nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH) kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE) kb_2.load_bulk(KB_FILE)
@ -130,20 +139,26 @@ def run_pipeline():
# STEP 5: create a training dataset from WP # STEP 5: create a training dataset from WP
if create_wp_training: if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now()) print("STEP 5: create training dataset", now())
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP, training_set_creator.create_training(
wikipedia_input=ENWIKI_DUMP,
entity_def_input=ENTITY_DEFS, entity_def_input=ENTITY_DEFS,
training_output=TRAINING_DIR) training_output=TRAINING_DIR,
)
# STEP 6: create and train the entity linking pipe # STEP 6: create and train the entity linking pipe
if train_pipe: if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) print("STEP 6: training Entity Linking pipe", now())
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)} type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
print(" -analysing", len(type_to_int), "different entity types") print(" -analysing", len(type_to_int), "different entity types")
el_pipe = nlp_2.create_pipe(name='entity_linker', el_pipe = nlp_2.create_pipe(
config={"context_width": CONTEXT_WIDTH, name="entity_linker",
config={
"context_width": CONTEXT_WIDTH,
"pretrained_vectors": nlp_2.vocab.vectors.name, "pretrained_vectors": nlp_2.vocab.vectors.name,
"type_to_int": type_to_int}) "type_to_int": type_to_int,
},
)
el_pipe.set_kb(kb_2) el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True) nlp_2.add_pipe(el_pipe, last=True)
@ -157,18 +172,22 @@ def run_pipeline():
train_limit = 5000 train_limit = 5000
dev_limit = 5000 dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2, # for training, get pos & neg instances that correspond to entries in the kb
train_data = training_set_creator.read_training(
nlp=nlp_2,
training_dir=TRAINING_DIR, training_dir=TRAINING_DIR,
dev=False, dev=False,
limit=train_limit) limit=train_limit,
kb=el_pipe.kb,
)
print("Training on", len(train_data), "articles") print("Training on", len(train_data), "articles")
print() print()
dev_data = training_set_creator.read_training(nlp=nlp_2, # for testing, get all pos instances, whether or not they are in the kb
training_dir=TRAINING_DIR, dev_data = training_set_creator.read_training(
dev=True, nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
limit=dev_limit) )
print("Dev testing on", len(dev_data), "articles") print("Dev testing on", len(dev_data), "articles")
print() print()
@ -187,8 +206,8 @@ def run_pipeline():
try: try:
docs, golds = zip(*batch) docs, golds = zip(*batch)
nlp_2.update( nlp_2.update(
docs, docs=docs,
golds, golds=golds,
sgd=optimizer, sgd=optimizer,
drop=DROPOUT, drop=DROPOUT,
losses=losses, losses=losses,
@ -200,48 +219,61 @@ def run_pipeline():
if batchnr > 0: if batchnr > 0:
el_pipe.cfg["context_weight"] = 1 el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1 el_pipe.cfg["prior_weight"] = 1
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses['entity_linker'] = losses['entity_linker'] / batchnr losses["entity_linker"] = losses["entity_linker"] / batchnr
print("Epoch, train loss", itn, round(losses['entity_linker'], 2), print(
" / dev acc avg", round(dev_acc_context, 3)) "Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev acc avg",
round(dev_acc_context, 3),
)
# STEP 7: measure the performance of our trained pipe on an independent dev set # STEP 7: measure the performance of our trained pipe on an independent dev set
if len(dev_data) and measure_performance: if len(dev_data) and measure_performance:
print() print()
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now()) print("STEP 7: performance measurement of Entity Linking pipe", now())
print() print()
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2) counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb_2
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()]) oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()]) print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev acc random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev acc prior:", round(acc_p, 3), prior_by_label)
# using only context # using only context
el_pipe.cfg["context_weight"] = 1 el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 0 el_pipe.cfg["prior_weight"] = 0
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()]) print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context) # measuring combined accuracy (prior + context)
el_pipe.cfg["context_weight"] = 1 el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1 el_pipe.cfg["prior_weight"] = 1
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False) dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
# STEP 8: apply the EL pipe on a toy example # STEP 8: apply the EL pipe on a toy example
if to_test_pipeline: if to_test_pipeline:
print() print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) print("STEP 8: applying Entity Linking to toy example", now())
print() print()
run_el_toy_example(nlp=nlp_2) run_el_toy_example(nlp=nlp_2)
# STEP 9: write the NLP pipeline (including entity linker) to file # STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp: if to_write_nlp:
print() print()
print("STEP 9: testing NLP IO", datetime.datetime.now()) print("STEP 9: testing NLP IO", now())
print() print()
print("writing to", NLP_2_DIR) print("writing to", NLP_2_DIR)
nlp_2.to_disk(NLP_2_DIR) nlp_2.to_disk(NLP_2_DIR)
@ -262,23 +294,26 @@ def run_pipeline():
el_pipe = nlp_3.get_pipe("entity_linker") el_pipe = nlp_3.get_pipe("entity_linker")
dev_limit = 5000 dev_limit = 5000
dev_data = training_set_creator.read_training(nlp=nlp_2, dev_data = training_set_creator.read_training(
nlp=nlp_2,
training_dir=TRAINING_DIR, training_dir=TRAINING_DIR,
dev=True, dev=True,
limit=dev_limit) limit=dev_limit,
kb=el_pipe.kb,
)
print("Dev testing from file on", len(dev_data), "articles") print("Dev testing from file on", len(dev_data), "articles")
print() print()
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False) dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
print() print()
print("STOP", datetime.datetime.now()) print("STOP", now())
def _measure_accuracy(data, el_pipe=None, error_analysis=False): def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe # If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict() correct_by_label = dict()
incorrect_by_label = dict() incorrect_by_label = dict()
@ -291,7 +326,9 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity in gold.links: for entity, value in gold.links.items():
# only evaluating on positive examples
if value:
start, end, gold_kb = entity start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
@ -300,7 +337,8 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
pred_entity = ent.kb_id_ pred_entity = ent.kb_id_
start = ent.start_char start = ent.start_char
end = ent.end_char end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) offset = str(start) + "-" + str(end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
if gold_entity == pred_entity: if gold_entity == pred_entity:
@ -311,7 +349,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
incorrect_by_label[ent_label] = incorrect + 1 incorrect_by_label[ent_label] = incorrect + 1
if error_analysis: if error_analysis:
print(ent.text, "in", doc) print(ent.text, "in", doc)
print("Predicted", pred_entity, "should have been", gold_entity) print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print() print()
except Exception as e: except Exception as e:
@ -323,16 +366,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
def _measure_baselines(data, kb): def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_by_label = dict() counts_d = dict()
random_correct_by_label = dict() random_correct_d = dict()
random_incorrect_by_label = dict() random_incorrect_d = dict()
oracle_correct_by_label = dict() oracle_correct_d = dict()
oracle_incorrect_by_label = dict() oracle_incorrect_d = dict()
prior_correct_by_label = dict() prior_correct_d = dict()
prior_incorrect_by_label = dict() prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0] docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0] golds = [g for d, g in data if len(d) > 0]
@ -345,14 +388,15 @@ def _measure_baselines(data, kb):
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ label = ent.label_
start = ent.start_char start = ent.start_char
end = ent.end_char end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) offset = str(start) + "-" + str(end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1 counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text) candidates = kb.get_candidates(ent.text)
oracle_candidate = "" oracle_candidate = ""
best_candidate = "" best_candidate = ""
@ -370,28 +414,36 @@ def _measure_baselines(data, kb):
random_candidate = random.choice(candidates).entity_ random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate: if gold_entity == best_candidate:
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1 prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else: else:
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1 prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate: if gold_entity == random_candidate:
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1 random_correct_d[label] = random_correct_d.get(label, 0) + 1
else: else:
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1 random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate: if gold_entity == oracle_candidate:
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1 oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else: else:
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1 oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e: except Exception as e:
print("Error assessing accuracy", e) print("Error assessing accuracy", e)
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label) acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label) acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label) acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
)
def calculate_acc(correct_by_label, incorrect_by_label): def calculate_acc(correct_by_label, incorrect_by_label):
@ -422,15 +474,23 @@ def check_kb(kb):
print("generating candidates for " + mention + " :") print("generating candidates for " + mention + " :")
for c in candidates: for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")") print(
" ",
c.prior_prob,
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print() print()
def run_el_toy_example(nlp): def run_el_toy_example(nlp):
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \ text = (
"Douglas reminds us to always bring our towel, even in China or Brazil. " \ "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"The main character in Doug's novel is the man Arthur Dent, " \ "Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, "
"but Douglas doesn't write about George Washington or Homer Simpson." "but Douglas doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text) doc = nlp(text)
print(text) print(text)
for ent in doc.ents: for ent in doc.ents:

View File

@ -208,7 +208,7 @@ cdef class KnowledgeBase:
# Return an empty list if this entity is unknown in this KB # Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index: if entity_hash not in self._entry_index:
return [] return [0] * self.entity_vector_length
entry_index = self._entry_index[entity_hash] entry_index = self._entry_index[entity_hash]
return self._vectors_table[self._entries[entry_index].vector_index] return self._vectors_table[self._entries[entry_index].vector_index]

View File

@ -1151,9 +1151,11 @@ class EntityLinker(Pipe):
ents_by_offset = dict() ents_by_offset = dict()
for ent in doc.ents: for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
for entity in gold.links: for entity, value in gold.links.items():
start, end, gold_kb = entity start, end, kb_id = entity
mention = doc.text[start:end] mention = doc.text[start:end]
entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention)
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
assert gold_ent is not None assert gold_ent is not None
@ -1161,24 +1163,17 @@ class EntityLinker(Pipe):
if len(type_to_int) > 0: if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1 type_vector[type_to_int[gold_ent.label_]] = 1
candidates = self.kb.get_candidates(mention) # store data
random.shuffle(candidates)
for c in candidates:
kb_id = c.entity_
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding) entity_encodings.append(entity_encoding)
context_docs.append(doc) context_docs.append(doc)
type_vectors.append(type_vector) type_vectors.append(type_vector)
if self.cfg.get("prior_weight", 1) > 0: if self.cfg.get("prior_weight", 1) > 0:
priors.append([c.prior_prob]) priors.append([prior_prob])
else: else:
priors.append([0]) priors.append([0])
if kb_id == gold_kb: cats.append([value])
cats.append([1])
else:
cats.append([0])
if len(entity_encodings) > 0: if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors) assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)