mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
filter training data beforehand (+black formatting)
This commit is contained in:
parent
d833d4c358
commit
ec55d2fccd
|
@ -18,6 +18,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
|
|||
ENTITY_FILE = "gold_entities.csv"
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def create_training(wikipedia_input, entity_def_input, training_output):
|
||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
|
||||
|
@ -54,12 +58,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
reading_revision = False
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 1000000 == 0:
|
||||
print(
|
||||
datetime.datetime.now(),
|
||||
"processed",
|
||||
cnt,
|
||||
"lines of Wikipedia dump",
|
||||
)
|
||||
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
if clean_line == "<revision>":
|
||||
|
@ -328,8 +327,9 @@ def is_dev(article_id):
|
|||
|
||||
def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
When kb is provided, it will include also negative training examples by using the candidate generator.
|
||||
When kb=None, it will only include positive training examples."""
|
||||
When kb is provided (for training), it will include negative training examples by using the candidate generator,
|
||||
and it will only keep positive training examples that can be found in the KB.
|
||||
When kb=None (for testing), it will include all positive examples only."""
|
||||
entityfile_loc = training_dir / ENTITY_FILE
|
||||
data = []
|
||||
|
||||
|
@ -402,12 +402,11 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
|||
gold_end = int(end) - found_ent.sent.start_char
|
||||
|
||||
# add both pos and neg examples (in random order)
|
||||
# this will exclude examples not in the KB
|
||||
if kb:
|
||||
gold_entities = {}
|
||||
candidates = kb.get_candidates(alias)
|
||||
candidate_ids = [c.entity_ for c in candidates]
|
||||
# add positive example in case the KB doesn't have it
|
||||
candidate_ids.append(wd_id)
|
||||
random.shuffle(candidate_ids)
|
||||
for kb_id in candidate_ids:
|
||||
entry = (gold_start, gold_end, kb_id)
|
||||
|
@ -415,6 +414,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
|||
gold_entities[entry] = 0.0
|
||||
else:
|
||||
gold_entities[entry] = 1.0
|
||||
# keep all positive examples
|
||||
else:
|
||||
entry = (gold_start, gold_end, wd_id)
|
||||
gold_entities = {entry: 1.0}
|
||||
|
|
|
@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
|
|||
map_alias_to_link = dict()
|
||||
|
||||
# these will/should be matched ignoring case
|
||||
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
|
||||
"d", "dbdump", "download", "Draft", "Education", "Foundation",
|
||||
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
|
||||
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
|
||||
"MediaZilla", "Meta", "Metawikipedia", "Module",
|
||||
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
|
||||
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
|
||||
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
|
||||
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
|
||||
"tswiki", "User", "User talk", "v", "voy",
|
||||
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
|
||||
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
|
||||
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
|
||||
wiki_namespaces = [
|
||||
"b",
|
||||
"betawikiversity",
|
||||
"Book",
|
||||
"c",
|
||||
"Category",
|
||||
"Commons",
|
||||
"d",
|
||||
"dbdump",
|
||||
"download",
|
||||
"Draft",
|
||||
"Education",
|
||||
"Foundation",
|
||||
"Gadget",
|
||||
"Gadget definition",
|
||||
"gerrit",
|
||||
"File",
|
||||
"Help",
|
||||
"Image",
|
||||
"Incubator",
|
||||
"m",
|
||||
"mail",
|
||||
"mailarchive",
|
||||
"media",
|
||||
"MediaWiki",
|
||||
"MediaWiki talk",
|
||||
"Mediawikiwiki",
|
||||
"MediaZilla",
|
||||
"Meta",
|
||||
"Metawikipedia",
|
||||
"Module",
|
||||
"mw",
|
||||
"n",
|
||||
"nost",
|
||||
"oldwikisource",
|
||||
"outreach",
|
||||
"outreachwiki",
|
||||
"otrs",
|
||||
"OTRSwiki",
|
||||
"Portal",
|
||||
"phab",
|
||||
"Phabricator",
|
||||
"Project",
|
||||
"q",
|
||||
"quality",
|
||||
"rev",
|
||||
"s",
|
||||
"spcom",
|
||||
"Special",
|
||||
"species",
|
||||
"Strategy",
|
||||
"sulutil",
|
||||
"svn",
|
||||
"Talk",
|
||||
"Template",
|
||||
"Template talk",
|
||||
"Testwiki",
|
||||
"ticket",
|
||||
"TimedText",
|
||||
"Toollabs",
|
||||
"tools",
|
||||
"tswiki",
|
||||
"User",
|
||||
"User talk",
|
||||
"v",
|
||||
"voy",
|
||||
"w",
|
||||
"Wikibooks",
|
||||
"Wikidata",
|
||||
"wikiHow",
|
||||
"Wikinvest",
|
||||
"wikilivres",
|
||||
"Wikimedia",
|
||||
"Wikinews",
|
||||
"Wikipedia",
|
||||
"Wikipedia talk",
|
||||
"Wikiquote",
|
||||
"Wikisource",
|
||||
"Wikispecies",
|
||||
"Wikitech",
|
||||
"Wikiversity",
|
||||
"Wikivoyage",
|
||||
"wikt",
|
||||
"wiktionary",
|
||||
"wmf",
|
||||
"wmania",
|
||||
"WP",
|
||||
]
|
||||
|
||||
# find the links
|
||||
link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
|
||||
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
||||
|
||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||
|
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
|
|||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||
|
||||
|
||||
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def read_prior_probs(wikipedia_input, prior_prob_output):
|
||||
"""
|
||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||
The full file takes about 2h to parse 1100M lines.
|
||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
||||
"""
|
||||
with bz2.open(wikipedia_input, mode='rb') as file:
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
while line:
|
||||
if cnt % 5000000 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
|
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
|||
cnt += 1
|
||||
|
||||
# write all aliases and their entities and count occurrences to file
|
||||
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
|
||||
with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
|
||||
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
||||
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
||||
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
|
||||
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
for entity, count in s_dict:
|
||||
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
|
||||
|
||||
|
||||
|
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
||||
with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
||||
while line:
|
||||
splits = line.replace('\n', "").split(sep='|')
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
# alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
|
||||
line = prior_file.readline()
|
||||
|
||||
with open(count_output, mode='w', encoding='utf8') as entity_file:
|
||||
with open(count_output, mode="w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
|
||||
def get_all_frequencies(count_input):
|
||||
entity_to_count = dict()
|
||||
with open(count_input, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
with open(count_input, "r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_count[row[0]] = int(row[1])
|
||||
|
||||
return entity_to_count
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@ import random
|
|||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking import training_set_creator, kb_creator
|
||||
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||
|
||||
import spacy
|
||||
|
@ -17,23 +18,25 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
|
|||
"""
|
||||
|
||||
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
|
||||
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
|
||||
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
|
||||
OUTPUT_DIR = ROOT_DIR / "wikipedia"
|
||||
TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
|
||||
|
||||
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
|
||||
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
|
||||
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
|
||||
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
|
||||
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
|
||||
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
|
||||
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
|
||||
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
|
||||
|
||||
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
|
||||
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
|
||||
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
|
||||
KB_FILE = OUTPUT_DIR / "kb_1" / "kb"
|
||||
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
|
||||
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
|
||||
|
||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
|
||||
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
|
||||
|
||||
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
|
||||
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
|
||||
ENWIKI_DUMP = (
|
||||
ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
|
||||
)
|
||||
|
||||
# KB construction parameters
|
||||
MAX_CANDIDATES = 10
|
||||
|
@ -48,11 +51,15 @@ L2 = 1e-6
|
|||
CONTEXT_WIDTH = 128
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def run_pipeline():
|
||||
# set the appropriate booleans to define which parts of the pipeline should be re(run)
|
||||
print("START", datetime.datetime.now())
|
||||
print("START", now())
|
||||
print()
|
||||
nlp_1 = spacy.load('en_core_web_lg')
|
||||
nlp_1 = spacy.load("en_core_web_lg")
|
||||
nlp_2 = None
|
||||
kb_2 = None
|
||||
|
||||
|
@ -82,20 +89,21 @@ def run_pipeline():
|
|||
|
||||
# STEP 1 : create prior probabilities from WP (run only once)
|
||||
if to_create_prior_probs:
|
||||
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
|
||||
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
|
||||
print("STEP 1: to_create_prior_probs", now())
|
||||
wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
|
||||
print()
|
||||
|
||||
# STEP 2 : deduce entity frequencies from WP (run only once)
|
||||
if to_create_entity_counts:
|
||||
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
|
||||
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
|
||||
print("STEP 2: to_create_entity_counts", now())
|
||||
wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
|
||||
print()
|
||||
|
||||
# STEP 3 : create KB and write to file (run only once)
|
||||
if to_create_kb:
|
||||
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
||||
kb_1 = kb_creator.create_kb(nlp_1,
|
||||
print("STEP 3a: to_create_kb", now())
|
||||
kb_1 = kb_creator.create_kb(
|
||||
nlp=nlp_1,
|
||||
max_entities_per_alias=MAX_CANDIDATES,
|
||||
min_entity_freq=MIN_ENTITY_FREQ,
|
||||
min_occ=MIN_PAIR_OCC,
|
||||
|
@ -103,19 +111,20 @@ def run_pipeline():
|
|||
entity_descr_output=ENTITY_DESCR,
|
||||
count_input=ENTITY_COUNTS,
|
||||
prior_prob_input=PRIOR_PROB,
|
||||
wikidata_input=WIKIDATA_JSON)
|
||||
wikidata_input=WIKIDATA_JSON,
|
||||
)
|
||||
print("kb entities:", kb_1.get_size_entities())
|
||||
print("kb aliases:", kb_1.get_size_aliases())
|
||||
print()
|
||||
|
||||
print("STEP 3b: write KB and NLP", datetime.datetime.now())
|
||||
print("STEP 3b: write KB and NLP", now())
|
||||
kb_1.dump(KB_FILE)
|
||||
nlp_1.to_disk(NLP_1_DIR)
|
||||
print()
|
||||
|
||||
# STEP 4 : read KB back in from file
|
||||
if to_read_kb:
|
||||
print("STEP 4: to_read_kb", datetime.datetime.now())
|
||||
print("STEP 4: to_read_kb", now())
|
||||
nlp_2 = spacy.load(NLP_1_DIR)
|
||||
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||
kb_2.load_bulk(KB_FILE)
|
||||
|
@ -130,20 +139,26 @@ def run_pipeline():
|
|||
|
||||
# STEP 5: create a training dataset from WP
|
||||
if create_wp_training:
|
||||
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
|
||||
print("STEP 5: create training dataset", now())
|
||||
training_set_creator.create_training(
|
||||
wikipedia_input=ENWIKI_DUMP,
|
||||
entity_def_input=ENTITY_DEFS,
|
||||
training_output=TRAINING_DIR)
|
||||
training_output=TRAINING_DIR,
|
||||
)
|
||||
|
||||
# STEP 6: create and train the entity linking pipe
|
||||
if train_pipe:
|
||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||
print("STEP 6: training Entity Linking pipe", now())
|
||||
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
|
||||
print(" -analysing", len(type_to_int), "different entity types")
|
||||
el_pipe = nlp_2.create_pipe(name='entity_linker',
|
||||
config={"context_width": CONTEXT_WIDTH,
|
||||
el_pipe = nlp_2.create_pipe(
|
||||
name="entity_linker",
|
||||
config={
|
||||
"context_width": CONTEXT_WIDTH,
|
||||
"pretrained_vectors": nlp_2.vocab.vectors.name,
|
||||
"type_to_int": type_to_int})
|
||||
"type_to_int": type_to_int,
|
||||
},
|
||||
)
|
||||
el_pipe.set_kb(kb_2)
|
||||
nlp_2.add_pipe(el_pipe, last=True)
|
||||
|
||||
|
@ -157,18 +172,22 @@ def run_pipeline():
|
|||
train_limit = 5000
|
||||
dev_limit = 5000
|
||||
|
||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
# for training, get pos & neg instances that correspond to entries in the kb
|
||||
train_data = training_set_creator.read_training(
|
||||
nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=False,
|
||||
limit=train_limit)
|
||||
limit=train_limit,
|
||||
kb=el_pipe.kb,
|
||||
)
|
||||
|
||||
print("Training on", len(train_data), "articles")
|
||||
print()
|
||||
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
# for testing, get all pos instances, whether or not they are in the kb
|
||||
dev_data = training_set_creator.read_training(
|
||||
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
|
||||
)
|
||||
|
||||
print("Dev testing on", len(dev_data), "articles")
|
||||
print()
|
||||
|
@ -187,8 +206,8 @@ def run_pipeline():
|
|||
try:
|
||||
docs, golds = zip(*batch)
|
||||
nlp_2.update(
|
||||
docs,
|
||||
golds,
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
sgd=optimizer,
|
||||
drop=DROPOUT,
|
||||
losses=losses,
|
||||
|
@ -200,48 +219,61 @@ def run_pipeline():
|
|||
if batchnr > 0:
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 1
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
||||
" / dev acc avg", round(dev_acc_context, 3))
|
||||
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
|
||||
losses["entity_linker"] = losses["entity_linker"] / batchnr
|
||||
print(
|
||||
"Epoch, train loss",
|
||||
itn,
|
||||
round(losses["entity_linker"], 2),
|
||||
" / dev acc avg",
|
||||
round(dev_acc_context, 3),
|
||||
)
|
||||
|
||||
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
||||
if len(dev_data) and measure_performance:
|
||||
print()
|
||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||
print("STEP 7: performance measurement of Entity Linking pipe", now())
|
||||
print()
|
||||
|
||||
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
|
||||
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
|
||||
dev_data, kb_2
|
||||
)
|
||||
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
|
||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
||||
|
||||
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
|
||||
print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
|
||||
|
||||
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
|
||||
print("dev acc random:", round(acc_r, 3), random_by_label)
|
||||
|
||||
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
|
||||
print("dev acc prior:", round(acc_p, 3), prior_by_label)
|
||||
|
||||
# using only context
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 0
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
||||
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
|
||||
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
|
||||
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 1
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
|
||||
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||
|
||||
# STEP 8: apply the EL pipe on a toy example
|
||||
if to_test_pipeline:
|
||||
print()
|
||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||
print("STEP 8: applying Entity Linking to toy example", now())
|
||||
print()
|
||||
run_el_toy_example(nlp=nlp_2)
|
||||
|
||||
# STEP 9: write the NLP pipeline (including entity linker) to file
|
||||
if to_write_nlp:
|
||||
print()
|
||||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||
print("STEP 9: testing NLP IO", now())
|
||||
print()
|
||||
print("writing to", NLP_2_DIR)
|
||||
nlp_2.to_disk(NLP_2_DIR)
|
||||
|
@ -262,23 +294,26 @@ def run_pipeline():
|
|||
el_pipe = nlp_3.get_pipe("entity_linker")
|
||||
|
||||
dev_limit = 5000
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
dev_data = training_set_creator.read_training(
|
||||
nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
limit=dev_limit,
|
||||
kb=el_pipe.kb,
|
||||
)
|
||||
|
||||
print("Dev testing from file on", len(dev_data), "articles")
|
||||
print()
|
||||
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
|
||||
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||
|
||||
print()
|
||||
print("STOP", datetime.datetime.now())
|
||||
print("STOP", now())
|
||||
|
||||
|
||||
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||
def _measure_acc(data, el_pipe=None, error_analysis=False):
|
||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||
correct_by_label = dict()
|
||||
incorrect_by_label = dict()
|
||||
|
@ -291,7 +326,9 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
|||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity in gold.links:
|
||||
for entity, value in gold.links.items():
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
start, end, gold_kb = entity
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
|
||||
|
@ -300,7 +337,8 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
|||
pred_entity = ent.kb_id_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||
offset = str(start) + "-" + str(end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
if gold_entity == pred_entity:
|
||||
|
@ -311,7 +349,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
|||
incorrect_by_label[ent_label] = incorrect + 1
|
||||
if error_analysis:
|
||||
print(ent.text, "in", doc)
|
||||
print("Predicted", pred_entity, "should have been", gold_entity)
|
||||
print(
|
||||
"Predicted",
|
||||
pred_entity,
|
||||
"should have been",
|
||||
gold_entity,
|
||||
)
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
|
@ -323,16 +366,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
|||
|
||||
def _measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
counts_by_label = dict()
|
||||
counts_d = dict()
|
||||
|
||||
random_correct_by_label = dict()
|
||||
random_incorrect_by_label = dict()
|
||||
random_correct_d = dict()
|
||||
random_incorrect_d = dict()
|
||||
|
||||
oracle_correct_by_label = dict()
|
||||
oracle_incorrect_by_label = dict()
|
||||
oracle_correct_d = dict()
|
||||
oracle_incorrect_d = dict()
|
||||
|
||||
prior_correct_by_label = dict()
|
||||
prior_incorrect_by_label = dict()
|
||||
prior_correct_d = dict()
|
||||
prior_incorrect_d = dict()
|
||||
|
||||
docs = [d for d, g in data if len(d) > 0]
|
||||
golds = [g for d, g in data if len(d) > 0]
|
||||
|
@ -345,14 +388,15 @@ def _measure_baselines(data, kb):
|
|||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
label = ent.label_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||
offset = str(start) + "-" + str(end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
|
||||
counts_d[label] = counts_d.get(label, 0) + 1
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
best_candidate = ""
|
||||
|
@ -370,28 +414,36 @@ def _measure_baselines(data, kb):
|
|||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
if gold_entity == best_candidate:
|
||||
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
|
||||
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
|
||||
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
|
||||
|
||||
if gold_entity == random_candidate:
|
||||
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
|
||||
random_correct_d[label] = random_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
|
||||
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
|
||||
|
||||
if gold_entity == oracle_candidate:
|
||||
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
|
||||
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
|
||||
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
|
||||
|
||||
except Exception as e:
|
||||
print("Error assessing accuracy", e)
|
||||
|
||||
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
|
||||
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
|
||||
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
|
||||
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
|
||||
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
|
||||
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
|
||||
|
||||
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
||||
return (
|
||||
counts_d,
|
||||
acc_rand,
|
||||
acc_rand_d,
|
||||
acc_prior,
|
||||
acc_prior_d,
|
||||
acc_oracle,
|
||||
acc_oracle_d,
|
||||
)
|
||||
|
||||
|
||||
def calculate_acc(correct_by_label, incorrect_by_label):
|
||||
|
@ -422,15 +474,23 @@ def check_kb(kb):
|
|||
|
||||
print("generating candidates for " + mention + " :")
|
||||
for c in candidates:
|
||||
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
|
||||
print(
|
||||
" ",
|
||||
c.prior_prob,
|
||||
c.alias_,
|
||||
"-->",
|
||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
def run_el_toy_example(nlp):
|
||||
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
||||
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
|
||||
"The main character in Doug's novel is the man Arthur Dent, " \
|
||||
text = (
|
||||
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
|
||||
"Douglas reminds us to always bring our towel, even in China or Brazil. "
|
||||
"The main character in Doug's novel is the man Arthur Dent, "
|
||||
"but Douglas doesn't write about George Washington or Homer Simpson."
|
||||
)
|
||||
doc = nlp(text)
|
||||
print(text)
|
||||
for ent in doc.ents:
|
||||
|
|
|
@ -208,7 +208,7 @@ cdef class KnowledgeBase:
|
|||
|
||||
# Return an empty list if this entity is unknown in this KB
|
||||
if entity_hash not in self._entry_index:
|
||||
return []
|
||||
return [0] * self.entity_vector_length
|
||||
entry_index = self._entry_index[entity_hash]
|
||||
|
||||
return self._vectors_table[self._entries[entry_index].vector_index]
|
||||
|
|
|
@ -1151,9 +1151,11 @@ class EntityLinker(Pipe):
|
|||
ents_by_offset = dict()
|
||||
for ent in doc.ents:
|
||||
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
for entity, value in gold.links.items():
|
||||
start, end, kb_id = entity
|
||||
mention = doc.text[start:end]
|
||||
entity_encoding = self.kb.get_vector(kb_id)
|
||||
prior_prob = self.kb.get_prior_prob(kb_id, mention)
|
||||
|
||||
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
|
||||
assert gold_ent is not None
|
||||
|
@ -1161,24 +1163,17 @@ class EntityLinker(Pipe):
|
|||
if len(type_to_int) > 0:
|
||||
type_vector[type_to_int[gold_ent.label_]] = 1
|
||||
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
random.shuffle(candidates)
|
||||
for c in candidates:
|
||||
kb_id = c.entity_
|
||||
entity_encoding = c.entity_vector
|
||||
# store data
|
||||
entity_encodings.append(entity_encoding)
|
||||
context_docs.append(doc)
|
||||
type_vectors.append(type_vector)
|
||||
|
||||
if self.cfg.get("prior_weight", 1) > 0:
|
||||
priors.append([c.prior_prob])
|
||||
priors.append([prior_prob])
|
||||
else:
|
||||
priors.append([0])
|
||||
|
||||
if kb_id == gold_kb:
|
||||
cats.append([1])
|
||||
else:
|
||||
cats.append([0])
|
||||
cats.append([value])
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||
|
|
Loading…
Reference in New Issue
Block a user