2019-05-06 11:56:56 +03:00
|
|
|
# coding: utf-8
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
2019-06-07 14:54:45 +03:00
|
|
|
import random
|
|
|
|
|
|
|
|
from spacy.util import minibatch, compounding
|
|
|
|
|
2019-06-18 14:20:40 +03:00
|
|
|
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
|
|
|
|
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
2019-05-06 11:56:56 +03:00
|
|
|
|
|
|
|
import spacy
|
|
|
|
from spacy.kb import KnowledgeBase
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
"""
|
|
|
|
Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
|
|
|
|
"""
|
|
|
|
|
|
|
|
# Root directory holding every input/output file of this example. Change this single
# value to run the example on another machine; all paths below are derived from it.
WIKI_DATA_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/'

# files produced / consumed by the one-time preprocessing steps (STEP 1-3 of run_pipeline)
PRIOR_PROB = WIKI_DATA_DIR + 'prior_prob.csv'
ENTITY_COUNTS = WIKI_DATA_DIR + 'entity_freq.csv'
ENTITY_DEFS = WIKI_DATA_DIR + 'entity_defs.csv'
ENTITY_DESCR = WIKI_DATA_DIR + 'entity_descriptions.csv'

# serialized knowledge base and spaCy pipelines
KB_FILE = WIKI_DATA_DIR + 'kb_1/kb'
NLP_1_DIR = WIKI_DATA_DIR + 'nlp_1'
NLP_2_DIR = WIKI_DATA_DIR + 'nlp_2'

# output directory of the training-set creation (STEP 5), read back for training
TRAINING_DIR = WIKI_DATA_DIR + 'training_data_nel/'

# KB construction parameters passed to kb_creator.create_kb
MAX_CANDIDATES = 10
MIN_ENTITY_FREQ = 20
MIN_PAIR_OCC = 5

# training hyperparameters of the entity linking pipe
EPOCHS = 10
DROPOUT = 0.1
LEARN_RATE = 0.005
L2 = 1e-6
|
2019-05-06 11:56:56 +03:00
|
|
|
|
2019-06-11 12:40:58 +03:00
|
|
|
|
|
|
|
def run_pipeline():
    """Run the Wikipedia/Wikidata entity linking example end to end.

    The boolean flags at the top of this function select which steps run:

    1-3. (one-time) derive prior probabilities and entity counts from Wikipedia, build
         the knowledge base, and write it together with ``nlp_1`` to disk;
    4.   load ``nlp_2`` and ``kb_2`` back from disk (optionally printing sample candidates);
    5.   (one-time) create the training dataset;
    6.   add an entity linker backed by ``kb_2`` to ``nlp_2`` and (optionally) train it;
    7.   print baseline and trained accuracies on the dev set;
    8.   apply the pipeline to a toy example;
    9.   write ``nlp_2`` to disk, read it back, and optionally re-run the toy example.

    Steps 6-9 operate on the pipeline and KB loaded in step 4 and are skipped (with a
    message) when those were not loaded.
    """
    print("START", datetime.datetime.now())
    print()
    nlp_1 = spacy.load('en_core_web_lg')
    nlp_2 = None
    kb_2 = None

    # one-time methods to create KB and write to file
    to_create_prior_probs = False
    to_create_entity_counts = False
    to_create_kb = False

    # read KB back in from file
    to_read_kb = True
    to_test_kb = False

    # create training dataset
    create_wp_training = False

    # train the EL pipe
    train_pipe = True
    measure_performance = True

    # test the EL pipe on a simple example
    to_test_pipeline = True

    # write the NLP object, read back in and test again
    to_write_nlp = False
    to_read_nlp = False

    # max nr of articles read for the training and dev sets
    train_limit = 25000
    dev_limit = 5000

    # STEP 1 : create prior probabilities from WP
    # run only once !
    if to_create_prior_probs:
        print("STEP 1: to_create_prior_probs", datetime.datetime.now())
        wp.read_wikipedia_prior_probs(prior_prob_output=PRIOR_PROB)
        print()

    # STEP 2 : deduce entity frequencies from WP
    # run only once !
    if to_create_entity_counts:
        print("STEP 2: to_create_entity_counts", datetime.datetime.now())
        wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
        print()

    # STEP 3 : create KB and write to file
    # run only once !
    if to_create_kb:
        print("STEP 3a: to_create_kb", datetime.datetime.now())
        kb_1 = kb_creator.create_kb(nlp_1,
                                    max_entities_per_alias=MAX_CANDIDATES,
                                    min_entity_freq=MIN_ENTITY_FREQ,
                                    min_occ=MIN_PAIR_OCC,
                                    entity_def_output=ENTITY_DEFS,
                                    entity_descr_output=ENTITY_DESCR,
                                    count_input=ENTITY_COUNTS,
                                    prior_prob_input=PRIOR_PROB,
                                    to_print=False)
        print("kb entities:", kb_1.get_size_entities())
        print("kb aliases:", kb_1.get_size_aliases())
        print()

        print("STEP 3b: write KB and NLP", datetime.datetime.now())
        kb_1.dump(KB_FILE)
        nlp_1.to_disk(NLP_1_DIR)
        print()

    # STEP 4 : read KB back in from file
    if to_read_kb:
        print("STEP 4: to_read_kb", datetime.datetime.now())
        nlp_2 = spacy.load(NLP_1_DIR)
        kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
        kb_2.load_bulk(KB_FILE)
        print("kb entities:", kb_2.get_size_entities())
        print("kb aliases:", kb_2.get_size_aliases())
        print()

        # test KB (only possible once it has been loaded)
        if to_test_kb:
            check_kb(kb_2)
            print()

    # STEP 5: create a training dataset from WP
    if create_wp_training:
        print("STEP 5: create training dataset", datetime.datetime.now())
        training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)

    if nlp_2 is None or kb_2 is None:
        # everything below needs the pipeline and KB read in STEP 4
        print("Skipping STEP 6-9: set to_read_kb to load the NLP pipeline and KB first")
        print()
    else:
        # STEP 6: create and train the entity linking pipe
        el_pipe, other_pipes, optimizer = _create_el_pipe(nlp_2, kb_2)

        # both sets default to empty so that later steps can test them unconditionally
        train_data = []
        dev_data = []
        if train_pipe:
            print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
            train_data = training_set_creator.read_training(nlp=nlp_2,
                                                            training_dir=TRAINING_DIR,
                                                            dev=False,
                                                            limit=train_limit)
            print("Training on", len(train_data), "articles")
            print()

        # the dev set is used both for per-epoch monitoring and for STEP 7
        if train_pipe or measure_performance:
            dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                          training_dir=TRAINING_DIR,
                                                          dev=True,
                                                          limit=dev_limit)
            print("Dev testing on", len(dev_data), "articles")
            print()

        if train_pipe:
            if not train_data:
                print("Did not find any training data")
            else:
                _train_el_pipe(nlp_2, el_pipe, other_pipes, optimizer, train_data, dev_data)

        # STEP 7: measure the performance of our trained pipe on an independent dev set
        if measure_performance and len(dev_data) > 0:
            _measure_el_performance(el_pipe, optimizer, dev_data, kb_2)

        # reset for follow-up tests: the dev evaluations in STEP 6/7 leave prior_weight at 0
        el_pipe.context_weight = 1
        el_pipe.prior_weight = 1

        # STEP 8: apply the EL pipe on a toy example
        if to_test_pipeline:
            print()
            print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
            print()
            run_el_toy_example(nlp=nlp_2)
            print()

        # STEP 9: write the NLP pipeline (including entity linker) to file
        nlp_3 = None
        if to_write_nlp:
            print()
            print("STEP 9: testing NLP IO", datetime.datetime.now())
            print()
            print("writing to", NLP_2_DIR)
            nlp_2.to_disk(NLP_2_DIR)
            print()
            print("reading from", NLP_2_DIR)
            nlp_3 = spacy.load(NLP_2_DIR)

        # verify that the IO has gone correctly
        if to_read_nlp:
            if nlp_3 is None:
                # nothing was written in this run: read a pipeline saved by an earlier run
                print("reading from", NLP_2_DIR)
                nlp_3 = spacy.load(NLP_2_DIR)
            print()
            print("running toy example with NLP 2")
            run_el_toy_example(nlp=nlp_3)

    print()
    print("STOP", datetime.datetime.now())


def _create_el_pipe(nlp, kb):
    """Append an entity linker backed by ``kb`` to ``nlp`` and initialize its optimizer.

    :return: ``(el_pipe, other_pipes, optimizer)`` where ``other_pipes`` lists the names
        of all other pipes, which are disabled whenever the linker is trained
    """
    el_pipe = nlp.create_pipe(name='entity_linker', config={})
    el_pipe.set_kb(kb)
    nlp.add_pipe(el_pipe, last=True)

    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
        optimizer = nlp.begin_training()
        optimizer.learn_rate = LEARN_RATE
        optimizer.L2 = L2
    return el_pipe, other_pipes, optimizer


def _train_el_pipe(nlp, el_pipe, other_pipes, optimizer, train_data, dev_data):
    """Train the entity linking pipe of ``nlp`` for EPOCHS epochs (STEP 6).

    After each epoch with at least one successful batch update, prints the mean train
    loss per batch and the context-only dev accuracy measured with the optimizer's
    averaged parameters; that evaluation leaves ``el_pipe.prior_weight`` at 0.

    :param train_data: list of ``(doc, gold)`` pairs; shuffled in place every epoch
    :param dev_data: ``(doc, gold)`` pairs scored after every epoch
    """
    for itn in range(EPOCHS):
        random.shuffle(train_data)
        losses = {}
        batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
        batchnr = 0

        with nlp.disable_pipes(*other_pipes):
            for batch in batches:
                try:
                    docs, golds = zip(*batch)
                    nlp.update(
                        docs,
                        golds,
                        sgd=optimizer,
                        drop=DROPOUT,
                        losses=losses,
                    )
                    batchnr += 1
                except Exception as e:
                    # best effort: report and skip a failing batch instead of aborting training
                    print("Error updating batch:", e)

        if batchnr > 0:
            with el_pipe.model.use_params(optimizer.averages):
                el_pipe.context_weight = 1
                el_pipe.prior_weight = 0
                dev_acc_context, _ = _measure_accuracy(dev_data, el_pipe)
            # presumably nlp.update sums the epoch's batch losses into `losses`;
            # dividing by batchnr reports the per-batch mean
            avg_loss = losses.get('entity_linker', 0.0) / batchnr
            print("Epoch, train loss", itn, round(avg_loss, 2),
                  " / dev acc context avg", round(dev_acc_context, 3))


def _measure_el_performance(el_pipe, optimizer, dev_data, kb):
    """Print baseline accuracies and the trained pipe's dev accuracies (STEP 7).

    The pipe is scored with the optimizer's averaged parameters twice: combining prior
    probabilities with context, and with context only. Leaves ``el_pipe.prior_weight`` at 0.
    """
    print()
    print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
    print()

    acc_r, acc_r_by_label, acc_p, acc_p_by_label, acc_o, acc_o_by_label = _measure_baselines(dev_data, kb)
    print("dev acc oracle:", round(acc_o, 3), _round_by_label(acc_o_by_label))
    print("dev acc random:", round(acc_r, 3), _round_by_label(acc_r_by_label))
    print("dev acc prior:", round(acc_p, 3), _round_by_label(acc_p_by_label))

    with el_pipe.model.use_params(optimizer.averages):
        # measuring combined accuracy (prior + context)
        el_pipe.context_weight = 1
        el_pipe.prior_weight = 1
        dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
        print("dev acc combo avg:", round(dev_acc_combo, 3), _round_by_label(dev_acc_combo_dict))

        # using only context
        el_pipe.context_weight = 1
        el_pipe.prior_weight = 0
        dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
        print("dev acc context avg:", round(dev_acc_context, 3), _round_by_label(dev_acc_context_dict))
        print()


def _round_by_label(acc_by_label):
    """Return ``(label, accuracy rounded to 3 decimals)`` pairs for printing."""
    return [(label, round(acc, 3)) for label, acc in acc_by_label.items()]
|
2019-06-11 12:40:58 +03:00
|
|
|
|
|
|
|
|
2019-06-11 15:18:20 +03:00
|
|
|
def _measure_accuracy(data, el_pipe):
    """Score the entity linker's predictions against the gold links of a dataset.

    All non-empty docs in ``data`` are run through ``el_pipe.pipe``. Each predicted entity
    whose character span has a gold link counts as correct when its ``kb_id_`` equals that
    link, and as incorrect otherwise, grouped by the entity label. Spans without a gold link
    are ignored. Errors while scoring a doc are printed, and that doc's remaining entities
    are skipped.

    :param data: sequence of ``(doc, gold)`` pairs, where ``gold.links`` yields
        ``(start_char, end_char, kb_id)`` triples. It is iterated twice, so it must not be
        a one-shot iterator.
    :param el_pipe: the entity linking pipe to evaluate
    :return: ``(acc, acc_by_label)`` as computed by ``calculate_acc``
    """
    correct_by_label = dict()
    incorrect_by_label = dict()

    predicted_docs = el_pipe.pipe([doc for doc, _ in data if len(doc) > 0])
    gold_parses = [gold for doc, gold in data if len(doc) > 0]

    for doc, gold in zip(predicted_docs, gold_parses):
        try:
            gold_kb_by_span = dict()
            for start, end, gold_kb in gold.links:
                gold_kb_by_span["%s-%s" % (start, end)] = gold_kb

            for ent in doc.ents:
                label, predicted = ent.label_, ent.kb_id_
                gold_entity = gold_kb_by_span.get("%s-%s" % (ent.start_char, ent.end_char))
                # the gold annotations are incomplete, so an unannotated span can't count as 'wrong'
                if gold_entity is None:
                    continue
                tally = correct_by_label if gold_entity == predicted else incorrect_by_label
                tally[label] = tally.get(label, 0) + 1
        except Exception as e:
            print("Error assessing accuracy", e)

    return calculate_acc(correct_by_label, incorrect_by_label)
|
|
|
|
|
|
|
|
|
|
|
|
def _measure_baselines(data, kb):
    """Measure three reference accuracies to compare the trained entity linker against.

    Every entity in a non-empty doc whose character span carries a gold link is looked up
    with ``kb.get_candidates(ent.text)``. Its candidates are scored in three ways:

    * random: a candidate picked with ``random.choice``;
    * prior: the candidate with the highest ``prior_prob`` (the first one on ties);
    * oracle: the gold entity whenever it is among the candidates, which is the upper
      bound for any linker restricted to these candidates.

    When there are no candidates, all three predictions are the empty string. Spans
    without a gold link are skipped because the gold annotations are incomplete. Errors
    while scoring a doc are printed, and that doc's remaining entities are skipped.

    :param data: sequence of ``(doc, gold)`` pairs, where ``gold.links`` yields
        ``(start_char, end_char, kb_id)`` triples. It is iterated twice, so it must not be
        a one-shot iterator.
    :param kb: knowledge base providing ``get_candidates(alias)``
    :return: ``(acc_random, acc_random_by_label, acc_prior, acc_prior_by_label,
        acc_oracle, acc_oracle_by_label)`` as computed by ``calculate_acc``
    """
    # per baseline: (correct counts by label, incorrect counts by label)
    tallies = {
        "prior": (dict(), dict()),
        "random": (dict(), dict()),
        "oracle": (dict(), dict()),
    }

    docs = [doc for doc, _ in data if len(doc) > 0]
    golds = [gold for doc, gold in data if len(doc) > 0]

    for doc, gold in zip(docs, golds):
        try:
            gold_kb_by_span = {"%s-%s" % (start, end): gold_kb for start, end, gold_kb in gold.links}

            for ent in doc.ents:
                ent_label = ent.label_
                gold_entity = gold_kb_by_span.get("%s-%s" % (ent.start_char, ent.end_char))
                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                if gold_entity is None:
                    continue

                predictions = {"prior": "", "random": "", "oracle": ""}
                candidates = kb.get_candidates(ent.text)
                if candidates:
                    priors = [cand.prior_prob for cand in candidates]
                    if any(cand.entity_ == gold_entity for cand in candidates):
                        predictions["oracle"] = gold_entity
                    best = max(range(len(priors)), key=priors.__getitem__)
                    predictions["prior"] = candidates[best].entity_
                    predictions["random"] = random.choice(candidates).entity_

                for baseline, predicted in predictions.items():
                    correct, incorrect = tallies[baseline]
                    bucket = correct if gold_entity == predicted else incorrect
                    bucket[ent_label] = bucket.get(ent_label, 0) + 1
        except Exception as e:
            print("Error assessing accuracy", e)

    acc_prior, acc_prior_by_label = calculate_acc(*tallies["prior"])
    acc_random, acc_random_by_label = calculate_acc(*tallies["random"])
    acc_oracle, acc_oracle_by_label = calculate_acc(*tallies["oracle"])

    return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_acc(correct_by_label, incorrect_by_label):
    """Compute the overall and per-label accuracy from correct/incorrect counts.

    :param correct_by_label: dict mapping a label to its number of correct predictions
    :param incorrect_by_label: dict mapping a label to its number of incorrect predictions
    :return: tuple ``(acc, acc_by_label)``.
        ``acc`` is the accuracy over all counts pooled together, or 0 when there are no
        counts at all.
        ``acc_by_label`` maps every label found in either dict, inserted in sorted order,
        to its own accuracy, or 0 when that label has no counts.
    """
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0

    all_keys = set(correct_by_label) | set(incorrect_by_label)
    for label in sorted(all_keys):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        total = correct + incorrect
        # float() forces true division: under Python 2 (which this file still targets via its
        # __future__ import) int / int would truncate every accuracy to 0 or 1
        acc_by_label[label] = float(correct) / total if total else 0

    grand_total = total_correct + total_incorrect
    acc = float(total_correct) / grand_total if grand_total else 0
    return acc, acc_by_label
|
|
|
|
|
|
|
|
|
2019-06-18 14:20:40 +03:00
|
|
|
def check_kb(kb):
    """Print the candidates ``kb`` generates for a handful of sample mentions.

    For each mention, this prints a header line, then one line per candidate with its
    prior probability, alias, entity id and entity frequency, followed by an empty line.
    """
    sample_mentions = ("Bush", "Douglas Adams", "Homer", "Brazil", "China")
    for mention in sample_mentions:
        candidates = kb.get_candidates(mention)

        print("generating candidates for " + mention + " :")
        for cand in candidates:
            freq_note = " (freq=" + str(cand.entity_freq) + ")"
            print(" ", cand.prior_prob, cand.alias_, "-->", cand.entity_ + freq_note)
        print()
|
2019-06-11 12:40:58 +03:00
|
|
|
|
|
|
|
|
2019-06-13 17:25:39 +03:00
|
|
|
def run_el_toy_example(nlp):
    """Run ``nlp`` on three hand-written texts and print their entities.

    For every text, the text itself is printed first. Then each entity is printed with its
    label and linked KB id. An empty line separates consecutive examples.
    """
    examples = [
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
        "Douglas reminds us to always bring our towel, even in China or Brazil. "
        "The main character in Doug's novel is the man Arthur Dent, "
        "but Douglas doesn't write about George Washington or Homer Simpson.",
        # Q4426480 is her husband
        "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "
        "She loved her husband William King dearly. ",
        # Q3568763 is her tutor
        "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "
        "She was tutored by her favorite physics tutor William King.",
    ]

    last_position = len(examples) - 1
    for position, text in enumerate(examples):
        doc = nlp(text)
        print(text)
        for ent in doc.ents:
            print("ent", ent.text, ent.label_, ent.kb_id_)
        # blank separator between examples, but not after the last one
        if position < last_position:
            print()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # script entry point: which steps run is configured by the flags inside run_pipeline
    run_pipeline()
|