mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Merge branch 'master' into develop
This commit is contained in:
commit
27106d6528
|
@ -0,0 +1,11 @@
|
||||||
|
TRAINING_DATA_FILE = "gold_entities.jsonl"
|
||||||
|
KB_FILE = "kb"
|
||||||
|
KB_MODEL_DIR = "nlp_kb"
|
||||||
|
OUTPUT_MODEL_DIR = "nlp"
|
||||||
|
|
||||||
|
PRIOR_PROB_PATH = "prior_prob.csv"
|
||||||
|
ENTITY_DEFS_PATH = "entity_defs.csv"
|
||||||
|
ENTITY_FREQ_PATH = "entity_freq.csv"
|
||||||
|
ENTITY_DESCR_PATH = "entity_descriptions.csv"
|
||||||
|
|
||||||
|
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
200
bin/wiki_entity_linking/entity_linker_evaluation.py
Normal file
200
bin/wiki_entity_linking/entity_linker_evaluation.py
Normal file
|
@ -0,0 +1,200 @@
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Metrics(object):
|
||||||
|
true_pos = 0
|
||||||
|
false_pos = 0
|
||||||
|
false_neg = 0
|
||||||
|
|
||||||
|
def update_results(self, true_entity, candidate):
|
||||||
|
candidate_is_correct = true_entity == candidate
|
||||||
|
|
||||||
|
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
|
||||||
|
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
|
||||||
|
self.true_pos += candidate_is_correct
|
||||||
|
self.false_neg += not candidate_is_correct
|
||||||
|
if candidate not in {"", "NIL"}:
|
||||||
|
self.false_pos += not candidate_is_correct
|
||||||
|
|
||||||
|
def calculate_precision(self):
|
||||||
|
if self.true_pos == 0:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
return self.true_pos / (self.true_pos + self.false_pos)
|
||||||
|
|
||||||
|
def calculate_recall(self):
|
||||||
|
if self.true_pos == 0:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
return self.true_pos / (self.true_pos + self.false_neg)
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationResults(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.metrics = Metrics()
|
||||||
|
self.metrics_by_label = defaultdict(Metrics)
|
||||||
|
|
||||||
|
def update_metrics(self, ent_label, true_entity, candidate):
|
||||||
|
self.metrics.update_results(true_entity, candidate)
|
||||||
|
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
|
||||||
|
|
||||||
|
def increment_false_negatives(self):
|
||||||
|
self.metrics.false_neg += 1
|
||||||
|
|
||||||
|
def report_metrics(self, model_name):
|
||||||
|
model_str = model_name.title()
|
||||||
|
recall = self.metrics.calculate_recall()
|
||||||
|
precision = self.metrics.calculate_precision()
|
||||||
|
return ("{}: ".format(model_str) +
|
||||||
|
"Recall = {} | ".format(round(recall, 3)) +
|
||||||
|
"Precision = {} | ".format(round(precision, 3)) +
|
||||||
|
"Precision by label = {}".format({k: v.calculate_precision()
|
||||||
|
for k, v in self.metrics_by_label.items()}))
|
||||||
|
|
||||||
|
|
||||||
|
class BaselineResults(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.random = EvaluationResults()
|
||||||
|
self.prior = EvaluationResults()
|
||||||
|
self.oracle = EvaluationResults()
|
||||||
|
|
||||||
|
def report_accuracy(self, model):
|
||||||
|
results = getattr(self, model)
|
||||||
|
return results.report_metrics(model)
|
||||||
|
|
||||||
|
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
|
||||||
|
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
|
||||||
|
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
|
||||||
|
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||||
|
|
||||||
|
|
||||||
|
def measure_performance(dev_data, kb, el_pipe):
|
||||||
|
baseline_accuracies = measure_baselines(
|
||||||
|
dev_data, kb
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(baseline_accuracies.report_accuracy("random"))
|
||||||
|
logger.info(baseline_accuracies.report_accuracy("prior"))
|
||||||
|
logger.info(baseline_accuracies.report_accuracy("oracle"))
|
||||||
|
|
||||||
|
# using only context
|
||||||
|
el_pipe.cfg["incl_context"] = True
|
||||||
|
el_pipe.cfg["incl_prior"] = False
|
||||||
|
results = get_eval_results(dev_data, el_pipe)
|
||||||
|
logger.info(results.report_metrics("context only"))
|
||||||
|
|
||||||
|
# measuring combined accuracy (prior + context)
|
||||||
|
el_pipe.cfg["incl_context"] = True
|
||||||
|
el_pipe.cfg["incl_prior"] = True
|
||||||
|
results = get_eval_results(dev_data, el_pipe)
|
||||||
|
logger.info(results.report_metrics("context and prior"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_eval_results(data, el_pipe=None):
|
||||||
|
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
golds = []
|
||||||
|
for d, g in tqdm(data, leave=False):
|
||||||
|
if len(d) > 0:
|
||||||
|
golds.append(g)
|
||||||
|
if el_pipe is not None:
|
||||||
|
docs.append(el_pipe(d))
|
||||||
|
else:
|
||||||
|
docs.append(d)
|
||||||
|
|
||||||
|
results = EvaluationResults()
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
||||||
|
try:
|
||||||
|
correct_entries_per_article = dict()
|
||||||
|
for entity, kb_dict in gold.links.items():
|
||||||
|
start, end = entity
|
||||||
|
# only evaluating on positive examples
|
||||||
|
for gold_kb, value in kb_dict.items():
|
||||||
|
if value:
|
||||||
|
offset = _offset(start, end)
|
||||||
|
correct_entries_per_article[offset] = gold_kb
|
||||||
|
if offset not in tagged_entries_per_article:
|
||||||
|
results.increment_false_negatives()
|
||||||
|
|
||||||
|
for ent in doc.ents:
|
||||||
|
ent_label = ent.label_
|
||||||
|
pred_entity = ent.kb_id_
|
||||||
|
start = ent.start_char
|
||||||
|
end = ent.end_char
|
||||||
|
offset = _offset(start, end)
|
||||||
|
gold_entity = correct_entries_per_article.get(offset, None)
|
||||||
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
|
if gold_entity is not None:
|
||||||
|
results.update_metrics(ent_label, gold_entity, pred_entity)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error assessing accuracy " + str(e))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def measure_baselines(data, kb):
|
||||||
|
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||||
|
counts_d = dict()
|
||||||
|
|
||||||
|
baseline_results = BaselineResults()
|
||||||
|
|
||||||
|
docs = [d for d, g in data if len(d) > 0]
|
||||||
|
golds = [g for d, g in data if len(d) > 0]
|
||||||
|
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
correct_entries_per_article = dict()
|
||||||
|
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
||||||
|
for entity, kb_dict in gold.links.items():
|
||||||
|
start, end = entity
|
||||||
|
for gold_kb, value in kb_dict.items():
|
||||||
|
# only evaluating on positive examples
|
||||||
|
if value:
|
||||||
|
offset = _offset(start, end)
|
||||||
|
correct_entries_per_article[offset] = gold_kb
|
||||||
|
if offset not in tagged_entries_per_article:
|
||||||
|
baseline_results.random.increment_false_negatives()
|
||||||
|
baseline_results.oracle.increment_false_negatives()
|
||||||
|
baseline_results.prior.increment_false_negatives()
|
||||||
|
|
||||||
|
for ent in doc.ents:
|
||||||
|
ent_label = ent.label_
|
||||||
|
start = ent.start_char
|
||||||
|
end = ent.end_char
|
||||||
|
offset = _offset(start, end)
|
||||||
|
gold_entity = correct_entries_per_article.get(offset, None)
|
||||||
|
|
||||||
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
|
if gold_entity is not None:
|
||||||
|
candidates = kb.get_candidates(ent.text)
|
||||||
|
oracle_candidate = ""
|
||||||
|
best_candidate = ""
|
||||||
|
random_candidate = ""
|
||||||
|
if candidates:
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for c in candidates:
|
||||||
|
scores.append(c.prior_prob)
|
||||||
|
if c.entity_ == gold_entity:
|
||||||
|
oracle_candidate = c.entity_
|
||||||
|
|
||||||
|
best_index = scores.index(max(scores))
|
||||||
|
best_candidate = candidates[best_index].entity_
|
||||||
|
random_candidate = random.choice(candidates).entity_
|
||||||
|
|
||||||
|
baseline_results.update_baselines(gold_entity, ent_label,
|
||||||
|
random_candidate, best_candidate, oracle_candidate)
|
||||||
|
|
||||||
|
return baseline_results
|
||||||
|
|
||||||
|
|
||||||
|
def _offset(start, end):
|
||||||
|
return "{}_{}".format(start, end)
|
|
@ -1,12 +1,20 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
import csv
|
||||||
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
|
import logging
|
||||||
|
import spacy
|
||||||
|
import sys
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
import csv
|
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||||
import datetime
|
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
||||||
|
|
||||||
|
csv.field_size_limit(sys.maxsize)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_kb(
|
def create_kb(
|
||||||
|
@ -14,52 +22,73 @@ def create_kb(
|
||||||
max_entities_per_alias,
|
max_entities_per_alias,
|
||||||
min_entity_freq,
|
min_entity_freq,
|
||||||
min_occ,
|
min_occ,
|
||||||
entity_def_output,
|
entity_def_input,
|
||||||
entity_descr_output,
|
entity_descr_path,
|
||||||
count_input,
|
count_input,
|
||||||
prior_prob_input,
|
prior_prob_input,
|
||||||
wikidata_input,
|
|
||||||
entity_vector_length,
|
entity_vector_length,
|
||||||
limit=None,
|
|
||||||
read_raw_data=True,
|
|
||||||
):
|
):
|
||||||
# Create the knowledge base from Wikidata entries
|
# Create the knowledge base from Wikidata entries
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
||||||
|
|
||||||
|
# read the mappings from file
|
||||||
|
title_to_id = get_entity_to_id(entity_def_input)
|
||||||
|
id_to_descr = get_id_to_description(entity_descr_path)
|
||||||
|
|
||||||
# check the length of the nlp vectors
|
# check the length of the nlp vectors
|
||||||
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
||||||
input_dim = nlp.vocab.vectors_length
|
input_dim = nlp.vocab.vectors_length
|
||||||
print("Loaded pre-trained vectors of size %s" % input_dim)
|
logger.info("Loaded pre-trained vectors of size %s" % input_dim)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The `nlp` object should have access to pre-trained word vectors, "
|
"The `nlp` object should have access to pre-trained word vectors, "
|
||||||
" cf. https://spacy.io/usage/models#languages."
|
" cf. https://spacy.io/usage/models#languages."
|
||||||
)
|
)
|
||||||
|
|
||||||
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
|
logger.info("Get entity frequencies")
|
||||||
if read_raw_data:
|
|
||||||
print()
|
|
||||||
print(now(), " * read wikidata entities:")
|
|
||||||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
|
|
||||||
wikidata_input, limit=limit
|
|
||||||
)
|
|
||||||
|
|
||||||
# write the title-ID and ID-description mappings to file
|
|
||||||
_write_entity_files(
|
|
||||||
entity_def_output, entity_descr_output, title_to_id, id_to_descr
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# read the mappings from file
|
|
||||||
title_to_id = get_entity_to_id(entity_def_output)
|
|
||||||
id_to_descr = get_id_to_description(entity_descr_output)
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(now(), " * get entity frequencies:")
|
|
||||||
print()
|
|
||||||
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
||||||
|
|
||||||
|
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
|
||||||
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
|
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
|
||||||
|
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
|
||||||
|
title_to_id,
|
||||||
|
id_to_descr,
|
||||||
|
entity_frequencies,
|
||||||
|
min_entity_freq
|
||||||
|
)
|
||||||
|
logger.info("Left with {} entities".format(len(description_list)))
|
||||||
|
|
||||||
|
logger.info("Train entity encoder")
|
||||||
|
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
||||||
|
encoder.train(description_list=description_list, to_print=True)
|
||||||
|
|
||||||
|
logger.info("Get entity embeddings:")
|
||||||
|
embeddings = encoder.apply_encoder(description_list)
|
||||||
|
|
||||||
|
logger.info("Adding {} entities".format(len(entity_list)))
|
||||||
|
kb.set_entities(
|
||||||
|
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Adding aliases")
|
||||||
|
_add_aliases(
|
||||||
|
kb,
|
||||||
|
title_to_id=filtered_title_to_id,
|
||||||
|
max_entities_per_alias=max_entities_per_alias,
|
||||||
|
min_occ=min_occ,
|
||||||
|
prior_prob_input=prior_prob_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("KB size: {} entities, {} aliases".format(
|
||||||
|
kb.get_size_entities(),
|
||||||
|
kb.get_size_aliases()))
|
||||||
|
|
||||||
|
logger.info("Done with kb")
|
||||||
|
return kb
|
||||||
|
|
||||||
|
|
||||||
|
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
||||||
|
min_entity_freq: int = 10):
|
||||||
filtered_title_to_id = dict()
|
filtered_title_to_id = dict()
|
||||||
entity_list = []
|
entity_list = []
|
||||||
description_list = []
|
description_list = []
|
||||||
|
@ -72,58 +101,7 @@ def create_kb(
|
||||||
description_list.append(desc)
|
description_list.append(desc)
|
||||||
frequency_list.append(freq)
|
frequency_list.append(freq)
|
||||||
filtered_title_to_id[title] = entity
|
filtered_title_to_id[title] = entity
|
||||||
|
return filtered_title_to_id, entity_list, description_list, frequency_list
|
||||||
print(len(title_to_id.keys()), "original titles")
|
|
||||||
kept_nr = len(filtered_title_to_id.keys())
|
|
||||||
print("kept", kept_nr, "entities with min. frequency", min_entity_freq)
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(now(), " * train entity encoder:")
|
|
||||||
print()
|
|
||||||
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
|
||||||
encoder.train(description_list=description_list, to_print=True)
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(now(), " * get entity embeddings:")
|
|
||||||
print()
|
|
||||||
embeddings = encoder.apply_encoder(description_list)
|
|
||||||
|
|
||||||
print(now(), " * adding", len(entity_list), "entities")
|
|
||||||
kb.set_entities(
|
|
||||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
alias_cnt = _add_aliases(
|
|
||||||
kb,
|
|
||||||
title_to_id=filtered_title_to_id,
|
|
||||||
max_entities_per_alias=max_entities_per_alias,
|
|
||||||
min_occ=min_occ,
|
|
||||||
prior_prob_input=prior_prob_input,
|
|
||||||
)
|
|
||||||
print()
|
|
||||||
print(now(), " * adding", alias_cnt, "aliases")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print()
|
|
||||||
print("# of entities in kb:", kb.get_size_entities())
|
|
||||||
print("# of aliases in kb:", kb.get_size_aliases())
|
|
||||||
|
|
||||||
print(now(), "Done with kb")
|
|
||||||
return kb
|
|
||||||
|
|
||||||
|
|
||||||
def _write_entity_files(
|
|
||||||
entity_def_output, entity_descr_output, title_to_id, id_to_descr
|
|
||||||
):
|
|
||||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
|
||||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
|
||||||
for title, qid in title_to_id.items():
|
|
||||||
id_file.write(title + "|" + str(qid) + "\n")
|
|
||||||
|
|
||||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
|
||||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
|
||||||
for qid, descr in id_to_descr.items():
|
|
||||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def get_entity_to_id(entity_def_output):
|
def get_entity_to_id(entity_def_output):
|
||||||
|
@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output):
|
||||||
return entity_to_id
|
return entity_to_id
|
||||||
|
|
||||||
|
|
||||||
def get_id_to_description(entity_descr_output):
|
def get_id_to_description(entity_descr_path):
|
||||||
id_to_desc = dict()
|
id_to_desc = dict()
|
||||||
with entity_descr_output.open("r", encoding="utf8") as csvfile:
|
with entity_descr_path.open("r", encoding="utf8") as csvfile:
|
||||||
csvreader = csv.reader(csvfile, delimiter="|")
|
csvreader = csv.reader(csvfile, delimiter="|")
|
||||||
# skip header
|
# skip header
|
||||||
next(csvreader)
|
next(csvreader)
|
||||||
|
@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output):
|
||||||
|
|
||||||
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
|
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
|
||||||
wp_titles = title_to_id.keys()
|
wp_titles = title_to_id.keys()
|
||||||
cnt = 0
|
|
||||||
|
|
||||||
# adding aliases with prior probabilities
|
# adding aliases with prior probabilities
|
||||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||||
|
@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
||||||
entities=selected_entities,
|
entities=selected_entities,
|
||||||
probabilities=prior_probs,
|
probabilities=prior_probs,
|
||||||
)
|
)
|
||||||
cnt += 1
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(e)
|
logger.error(e)
|
||||||
total_count = 0
|
total_count = 0
|
||||||
counts = []
|
counts = []
|
||||||
entities = []
|
entities = []
|
||||||
|
@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
||||||
previous_alias = new_alias
|
previous_alias = new_alias
|
||||||
|
|
||||||
line = prior_file.readline()
|
line = prior_file.readline()
|
||||||
return cnt
|
|
||||||
|
|
||||||
|
|
||||||
def now():
|
def read_nlp_kb(model_dir, kb_file):
|
||||||
return datetime.datetime.now()
|
nlp = spacy.load(model_dir)
|
||||||
|
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||||
|
kb.load_bulk(kb_file)
|
||||||
|
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
||||||
|
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
||||||
|
return nlp, kb
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
|
|
||||||
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from spacy._ml import zero_init, create_default_optimizer
|
from spacy._ml import zero_init, create_default_optimizer
|
||||||
|
@ -10,6 +11,8 @@ from thinc.v2v import Model
|
||||||
from thinc.api import chain
|
from thinc.api import chain
|
||||||
from thinc.neural._classes.affine import Affine
|
from thinc.neural._classes.affine import Affine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EntityEncoder:
|
class EntityEncoder:
|
||||||
"""
|
"""
|
||||||
|
@ -50,21 +53,19 @@ class EntityEncoder:
|
||||||
|
|
||||||
start = start + batch_size
|
start = start + batch_size
|
||||||
stop = min(stop + batch_size, len(description_list))
|
stop = min(stop + batch_size, len(description_list))
|
||||||
print("encoded:", stop, "entities")
|
logger.info("encoded: {} entities".format(stop))
|
||||||
|
|
||||||
return encodings
|
return encodings
|
||||||
|
|
||||||
def train(self, description_list, to_print=False):
|
def train(self, description_list, to_print=False):
|
||||||
processed, loss = self._train_model(description_list)
|
processed, loss = self._train_model(description_list)
|
||||||
if to_print:
|
if to_print:
|
||||||
print(
|
logger.info(
|
||||||
"Trained entity descriptions on",
|
"Trained entity descriptions on {} ".format(processed) +
|
||||||
processed,
|
"(non-unique) entities across {} ".format(self.epochs) +
|
||||||
"(non-unique) entities across",
|
"epochs"
|
||||||
self.epochs,
|
|
||||||
"epochs",
|
|
||||||
)
|
)
|
||||||
print("Final loss:", loss)
|
logger.info("Final loss: {}".format(loss))
|
||||||
|
|
||||||
def _train_model(self, description_list):
|
def _train_model(self, description_list):
|
||||||
best_loss = 1.0
|
best_loss = 1.0
|
||||||
|
@ -93,7 +94,7 @@ class EntityEncoder:
|
||||||
|
|
||||||
loss = self._update(batch)
|
loss = self._update(batch)
|
||||||
if batch_nr % 25 == 0:
|
if batch_nr % 25 == 0:
|
||||||
print("loss:", loss)
|
logger.info("loss: {} ".format(loss))
|
||||||
processed += len(batch)
|
processed += len(batch)
|
||||||
|
|
||||||
# in general, continue training if we haven't reached our ideal min yet
|
# in general, continue training if we haven't reached our ideal min yet
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import bz2
|
import bz2
|
||||||
import datetime
|
import json
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from bin.wiki_entity_linking import kb_creator
|
from bin.wiki_entity_linking import kb_creator
|
||||||
|
@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ENTITY_FILE = "gold_entities.csv"
|
ENTITY_FILE = "gold_entities.csv"
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def now():
|
def create_training_examples_and_descriptions(wikipedia_input,
|
||||||
return datetime.datetime.now()
|
entity_def_input,
|
||||||
|
description_output,
|
||||||
|
training_output,
|
||||||
def create_training(wikipedia_input, entity_def_input, training_output, limit=None):
|
parse_descriptions,
|
||||||
|
limit=None):
|
||||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||||
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit)
|
_process_wikipedia_texts(wikipedia_input,
|
||||||
|
wp_to_id,
|
||||||
|
description_output,
|
||||||
|
training_output,
|
||||||
|
parse_descriptions,
|
||||||
|
limit)
|
||||||
|
|
||||||
|
|
||||||
def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
|
def _process_wikipedia_texts(wikipedia_input,
|
||||||
|
wp_to_id,
|
||||||
|
output,
|
||||||
|
training_output,
|
||||||
|
parse_descriptions,
|
||||||
|
limit=None):
|
||||||
"""
|
"""
|
||||||
Read the XML wikipedia data to parse out training data:
|
Read the XML wikipedia data to parse out training data:
|
||||||
raw text data + positive instances
|
raw text data + positive instances
|
||||||
|
@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
||||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||||
|
|
||||||
read_ids = set()
|
read_ids = set()
|
||||||
entityfile_loc = training_output / ENTITY_FILE
|
|
||||||
with entityfile_loc.open("w", encoding="utf8") as entityfile:
|
|
||||||
# write entity training header file
|
|
||||||
_write_training_entity(
|
|
||||||
outputfile=entityfile,
|
|
||||||
article_id="article_id",
|
|
||||||
alias="alias",
|
|
||||||
entity="WD_id",
|
|
||||||
start="start",
|
|
||||||
end="end",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
|
||||||
|
if parse_descriptions:
|
||||||
|
_write_training_description(descr_file, "WD_id", "description")
|
||||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||||
line = file.readline()
|
article_count = 0
|
||||||
cnt = 0
|
|
||||||
article_text = ""
|
article_text = ""
|
||||||
article_title = None
|
article_title = None
|
||||||
article_id = None
|
article_id = None
|
||||||
reading_text = False
|
reading_text = False
|
||||||
reading_revision = False
|
reading_revision = False
|
||||||
while line and (not limit or cnt < limit):
|
|
||||||
if cnt % 1000000 == 0:
|
logger.info("Processed {} articles".format(article_count))
|
||||||
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
|
||||||
|
for line in file:
|
||||||
clean_line = line.strip().decode("utf-8")
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
if clean_line == "<revision>":
|
if clean_line == "<revision>":
|
||||||
|
@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
||||||
article_text = ""
|
article_text = ""
|
||||||
article_title = None
|
article_title = None
|
||||||
article_id = None
|
article_id = None
|
||||||
|
|
||||||
# finished reading this page
|
# finished reading this page
|
||||||
elif clean_line == "</page>":
|
elif clean_line == "</page>":
|
||||||
if article_id:
|
if article_id:
|
||||||
try:
|
clean_text, entities = _process_wp_text(
|
||||||
_process_wp_text(
|
|
||||||
wp_to_id,
|
|
||||||
entityfile,
|
|
||||||
article_id,
|
|
||||||
article_title,
|
|
||||||
article_text.strip(),
|
|
||||||
training_output,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(
|
|
||||||
"Error processing article", article_id, article_title, e
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
"Done processing a page, but couldn't find an article_id ?",
|
|
||||||
article_title,
|
article_title,
|
||||||
|
article_text,
|
||||||
|
wp_to_id
|
||||||
)
|
)
|
||||||
|
if clean_text is not None and entities is not None:
|
||||||
|
_write_training_entities(entity_file,
|
||||||
|
article_id,
|
||||||
|
clean_text,
|
||||||
|
entities)
|
||||||
|
|
||||||
|
if article_title in wp_to_id and parse_descriptions:
|
||||||
|
description = " ".join(clean_text[:1000].split(" ")[:-1])
|
||||||
|
_write_training_description(
|
||||||
|
descr_file,
|
||||||
|
wp_to_id[article_title],
|
||||||
|
description
|
||||||
|
)
|
||||||
|
article_count += 1
|
||||||
|
if article_count % 10000 == 0:
|
||||||
|
logger.info("Processed {} articles".format(article_count))
|
||||||
|
if limit and article_count >= limit:
|
||||||
|
break
|
||||||
article_text = ""
|
article_text = ""
|
||||||
article_title = None
|
article_title = None
|
||||||
article_id = None
|
article_id = None
|
||||||
|
@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
||||||
if ids:
|
if ids:
|
||||||
article_id = ids[0]
|
article_id = ids[0]
|
||||||
if article_id in read_ids:
|
if article_id in read_ids:
|
||||||
print(
|
logger.info(
|
||||||
"Found duplicate article ID", article_id, clean_line
|
"Found duplicate article ID", article_id, clean_line
|
||||||
) # This should never happen ...
|
) # This should never happen ...
|
||||||
read_ids.add(article_id)
|
read_ids.add(article_id)
|
||||||
|
@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
||||||
titles = title_regex.search(clean_line)
|
titles = title_regex.search(clean_line)
|
||||||
if titles:
|
if titles:
|
||||||
article_title = titles[0].strip()
|
article_title = titles[0].strip()
|
||||||
|
logger.info("Finished. Processed {} articles".format(article_count))
|
||||||
line = file.readline()
|
|
||||||
cnt += 1
|
|
||||||
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
|
||||||
|
|
||||||
|
|
||||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||||
|
|
||||||
|
|
||||||
def _process_wp_text(
|
|
||||||
wp_to_id, entityfile, article_id, article_title, article_text, training_output
|
|
||||||
):
|
|
||||||
found_entities = False
|
|
||||||
|
|
||||||
# ignore meta Wikipedia pages
|
|
||||||
if article_title.startswith("Wikipedia:"):
|
|
||||||
return
|
|
||||||
|
|
||||||
# remove the text tags
|
|
||||||
text = text_regex.search(article_text).group(0)
|
|
||||||
|
|
||||||
# stop processing if this is a redirect page
|
|
||||||
if text.startswith("#REDIRECT"):
|
|
||||||
return
|
|
||||||
|
|
||||||
# get the raw text without markup etc, keeping only interwiki links
|
|
||||||
clean_text = _get_clean_wp_text(text)
|
|
||||||
|
|
||||||
# read the text char by char to get the right offsets for the interwiki links
|
|
||||||
final_text = ""
|
|
||||||
open_read = 0
|
|
||||||
reading_text = True
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = False
|
|
||||||
reading_special_case = False
|
|
||||||
entity_buffer = ""
|
|
||||||
mention_buffer = ""
|
|
||||||
for index, letter in enumerate(clean_text):
|
|
||||||
if letter == "[":
|
|
||||||
open_read += 1
|
|
||||||
elif letter == "]":
|
|
||||||
open_read -= 1
|
|
||||||
elif letter == "|":
|
|
||||||
if reading_text:
|
|
||||||
final_text += letter
|
|
||||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
|
||||||
elif reading_entity:
|
|
||||||
reading_text = False
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = True
|
|
||||||
else:
|
|
||||||
reading_special_case = True
|
|
||||||
else:
|
|
||||||
if reading_entity:
|
|
||||||
entity_buffer += letter
|
|
||||||
elif reading_mention:
|
|
||||||
mention_buffer += letter
|
|
||||||
elif reading_text:
|
|
||||||
final_text += letter
|
|
||||||
else:
|
|
||||||
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
|
|
||||||
|
|
||||||
if open_read > 2:
|
|
||||||
reading_special_case = True
|
|
||||||
|
|
||||||
if open_read == 2 and reading_text:
|
|
||||||
reading_text = False
|
|
||||||
reading_entity = True
|
|
||||||
reading_mention = False
|
|
||||||
|
|
||||||
# we just finished reading an entity
|
|
||||||
if open_read == 0 and not reading_text:
|
|
||||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
|
||||||
reading_special_case = True
|
|
||||||
# Ignore cases with nested structures like File: handles etc
|
|
||||||
if not reading_special_case:
|
|
||||||
if not mention_buffer:
|
|
||||||
mention_buffer = entity_buffer
|
|
||||||
start = len(final_text)
|
|
||||||
end = start + len(mention_buffer)
|
|
||||||
qid = wp_to_id.get(entity_buffer, None)
|
|
||||||
if qid:
|
|
||||||
_write_training_entity(
|
|
||||||
outputfile=entityfile,
|
|
||||||
article_id=article_id,
|
|
||||||
alias=mention_buffer,
|
|
||||||
entity=qid,
|
|
||||||
start=start,
|
|
||||||
end=end,
|
|
||||||
)
|
|
||||||
found_entities = True
|
|
||||||
final_text += mention_buffer
|
|
||||||
|
|
||||||
entity_buffer = ""
|
|
||||||
mention_buffer = ""
|
|
||||||
|
|
||||||
reading_text = True
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = False
|
|
||||||
reading_special_case = False
|
|
||||||
|
|
||||||
if found_entities:
|
|
||||||
_write_training_article(
|
|
||||||
article_id=article_id,
|
|
||||||
clean_text=final_text,
|
|
||||||
training_output=training_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
info_regex = re.compile(r"{[^{]*?}")
|
info_regex = re.compile(r"{[^{]*?}")
|
||||||
htlm_regex = re.compile(r"<!--[^-]*-->")
|
htlm_regex = re.compile(r"<!--[^-]*-->")
|
||||||
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
|
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
|
||||||
|
@ -242,6 +148,29 @@ ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||||
|
|
||||||
|
|
||||||
|
def _process_wp_text(article_title, article_text, wp_to_id):
|
||||||
|
# ignore meta Wikipedia pages
|
||||||
|
if (
|
||||||
|
article_title.startswith("Wikipedia:") or
|
||||||
|
article_title.startswith("Kategori:")
|
||||||
|
):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# remove the text tags
|
||||||
|
text_search = text_regex.search(article_text)
|
||||||
|
if text_search is None:
|
||||||
|
return None, None
|
||||||
|
text = text_search.group(0)
|
||||||
|
|
||||||
|
# stop processing if this is a redirect page
|
||||||
|
if text.startswith("#REDIRECT"):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# get the raw text without markup etc, keeping only interwiki links
|
||||||
|
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
||||||
|
return clean_text, entities
|
||||||
|
|
||||||
|
|
||||||
def _get_clean_wp_text(article_text):
|
def _get_clean_wp_text(article_text):
|
||||||
clean_text = article_text.strip()
|
clean_text = article_text.strip()
|
||||||
|
|
||||||
|
@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text):
|
||||||
return clean_text.strip()
|
return clean_text.strip()
|
||||||
|
|
||||||
|
|
||||||
def _write_training_article(article_id, clean_text, training_output):
|
def _remove_links(clean_text, wp_to_id):
|
||||||
file_loc = training_output / "{}.txt".format(article_id)
|
# read the text char by char to get the right offsets for the interwiki links
|
||||||
with file_loc.open("w", encoding="utf8") as outputfile:
|
entities = []
|
||||||
outputfile.write(clean_text)
|
final_text = ""
|
||||||
|
open_read = 0
|
||||||
|
reading_text = True
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = False
|
||||||
|
reading_special_case = False
|
||||||
|
entity_buffer = ""
|
||||||
|
mention_buffer = ""
|
||||||
|
for index, letter in enumerate(clean_text):
|
||||||
|
if letter == "[":
|
||||||
|
open_read += 1
|
||||||
|
elif letter == "]":
|
||||||
|
open_read -= 1
|
||||||
|
elif letter == "|":
|
||||||
|
if reading_text:
|
||||||
|
final_text += letter
|
||||||
|
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||||
|
elif reading_entity:
|
||||||
|
reading_text = False
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = True
|
||||||
|
else:
|
||||||
|
reading_special_case = True
|
||||||
|
else:
|
||||||
|
if reading_entity:
|
||||||
|
entity_buffer += letter
|
||||||
|
elif reading_mention:
|
||||||
|
mention_buffer += letter
|
||||||
|
elif reading_text:
|
||||||
|
final_text += letter
|
||||||
|
else:
|
||||||
|
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
|
||||||
|
|
||||||
|
if open_read > 2:
|
||||||
|
reading_special_case = True
|
||||||
|
|
||||||
|
if open_read == 2 and reading_text:
|
||||||
|
reading_text = False
|
||||||
|
reading_entity = True
|
||||||
|
reading_mention = False
|
||||||
|
|
||||||
|
# we just finished reading an entity
|
||||||
|
if open_read == 0 and not reading_text:
|
||||||
|
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||||
|
reading_special_case = True
|
||||||
|
# Ignore cases with nested structures like File: handles etc
|
||||||
|
if not reading_special_case:
|
||||||
|
if not mention_buffer:
|
||||||
|
mention_buffer = entity_buffer
|
||||||
|
start = len(final_text)
|
||||||
|
end = start + len(mention_buffer)
|
||||||
|
qid = wp_to_id.get(entity_buffer, None)
|
||||||
|
if qid:
|
||||||
|
entities.append((mention_buffer, qid, start, end))
|
||||||
|
final_text += mention_buffer
|
||||||
|
|
||||||
|
entity_buffer = ""
|
||||||
|
mention_buffer = ""
|
||||||
|
|
||||||
|
reading_text = True
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = False
|
||||||
|
reading_special_case = False
|
||||||
|
return final_text, entities
|
||||||
|
|
||||||
|
|
||||||
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
|
def _write_training_description(outputfile, qid, description):
|
||||||
line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
|
if description is not None:
|
||||||
|
line = str(qid) + "|" + description + "\n"
|
||||||
|
outputfile.write(line)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||||
|
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
|
||||||
|
line = json.dumps(
|
||||||
|
{
|
||||||
|
"article_id": article_id,
|
||||||
|
"clean_text": clean_text,
|
||||||
|
"entities": entities_data
|
||||||
|
},
|
||||||
|
ensure_ascii=False) + "\n"
|
||||||
outputfile.write(line)
|
outputfile.write(line)
|
||||||
|
|
||||||
|
|
||||||
|
def read_training(nlp, entity_file_path, dev, limit, kb):
|
||||||
|
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||||
|
For training,, it will include negative training examples by using the candidate generator,
|
||||||
|
and it will only keep positive training examples that can be found by using the candidate generator.
|
||||||
|
For testing, it will include all positive examples only."""
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
data = []
|
||||||
|
num_entities = 0
|
||||||
|
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
|
||||||
|
|
||||||
|
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
|
||||||
|
with entity_file_path.open("r", encoding="utf8") as file:
|
||||||
|
with tqdm(total=limit, leave=False) as pbar:
|
||||||
|
for i, line in enumerate(file):
|
||||||
|
example = json.loads(line)
|
||||||
|
article_id = example["article_id"]
|
||||||
|
clean_text = example["clean_text"]
|
||||||
|
entities = example["entities"]
|
||||||
|
|
||||||
|
if dev != is_dev(article_id) or len(clean_text) >= 30000:
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = nlp(clean_text)
|
||||||
|
gold = get_gold_parse(doc, entities)
|
||||||
|
if gold and len(gold.links) > 0:
|
||||||
|
data.append((doc, gold))
|
||||||
|
num_entities += len(gold.links)
|
||||||
|
pbar.update(len(gold.links))
|
||||||
|
if limit and num_entities >= limit:
|
||||||
|
break
|
||||||
|
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gold_parse(doc, entities, dev, kb):
|
||||||
|
gold_entities = {}
|
||||||
|
tagged_ent_positions = set(
|
||||||
|
[(ent.start_char, ent.end_char) for ent in doc.ents]
|
||||||
|
)
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
entity_id = entity["entity"]
|
||||||
|
alias = entity["alias"]
|
||||||
|
start = entity["start"]
|
||||||
|
end = entity["end"]
|
||||||
|
|
||||||
|
candidates = kb.get_candidates(alias)
|
||||||
|
candidate_ids = [
|
||||||
|
c.entity_ for c in candidates
|
||||||
|
]
|
||||||
|
|
||||||
|
should_add_ent = (
|
||||||
|
dev or
|
||||||
|
(
|
||||||
|
(start, end) in tagged_ent_positions and
|
||||||
|
entity_id in candidate_ids and
|
||||||
|
len(candidates) > 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_add_ent:
|
||||||
|
value_by_id = {entity_id: 1.0}
|
||||||
|
if not dev:
|
||||||
|
random.shuffle(candidate_ids)
|
||||||
|
value_by_id.update({
|
||||||
|
kb_id: 0.0
|
||||||
|
for kb_id in candidate_ids
|
||||||
|
if kb_id != entity_id
|
||||||
|
})
|
||||||
|
gold_entities[(start, end)] = value_by_id
|
||||||
|
|
||||||
|
return GoldParse(doc, links=gold_entities)
|
||||||
|
|
||||||
|
|
||||||
def is_dev(article_id):
|
def is_dev(article_id):
|
||||||
return article_id.endswith("3")
|
return article_id.endswith("3")
|
||||||
|
|
||||||
|
|
||||||
def read_training(nlp, training_dir, dev, limit, kb=None):
|
|
||||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
|
||||||
When kb is provided (for training), it will include negative training examples by using the candidate generator,
|
|
||||||
and it will only keep positive training examples that can be found in the KB.
|
|
||||||
When kb=None (for testing), it will include all positive examples only."""
|
|
||||||
entityfile_loc = training_dir / ENTITY_FILE
|
|
||||||
data = []
|
|
||||||
|
|
||||||
# assume the data is written sequentially, so we can reuse the article docs
|
|
||||||
current_article_id = None
|
|
||||||
current_doc = None
|
|
||||||
ents_by_offset = dict()
|
|
||||||
skip_articles = set()
|
|
||||||
total_entities = 0
|
|
||||||
|
|
||||||
with entityfile_loc.open("r", encoding="utf8") as file:
|
|
||||||
for line in file:
|
|
||||||
if not limit or len(data) < limit:
|
|
||||||
fields = line.replace("\n", "").split(sep="|")
|
|
||||||
article_id = fields[0]
|
|
||||||
alias = fields[1]
|
|
||||||
wd_id = fields[2]
|
|
||||||
start = fields[3]
|
|
||||||
end = fields[4]
|
|
||||||
|
|
||||||
if (
|
|
||||||
dev == is_dev(article_id)
|
|
||||||
and article_id != "article_id"
|
|
||||||
and article_id not in skip_articles
|
|
||||||
):
|
|
||||||
if not current_doc or (current_article_id != article_id):
|
|
||||||
# parse the new article text
|
|
||||||
file_name = article_id + ".txt"
|
|
||||||
try:
|
|
||||||
training_file = training_dir / file_name
|
|
||||||
with training_file.open("r", encoding="utf8") as f:
|
|
||||||
text = f.read()
|
|
||||||
# threshold for convenience / speed of processing
|
|
||||||
if len(text) < 30000:
|
|
||||||
current_doc = nlp(text)
|
|
||||||
current_article_id = article_id
|
|
||||||
ents_by_offset = dict()
|
|
||||||
for ent in current_doc.ents:
|
|
||||||
sent_length = len(ent.sent)
|
|
||||||
# custom filtering to avoid too long or too short sentences
|
|
||||||
if 5 < sent_length < 100:
|
|
||||||
offset = "{}_{}".format(
|
|
||||||
ent.start_char, ent.end_char
|
|
||||||
)
|
|
||||||
ents_by_offset[offset] = ent
|
|
||||||
else:
|
|
||||||
skip_articles.add(article_id)
|
|
||||||
current_doc = None
|
|
||||||
except Exception as e:
|
|
||||||
print("Problem parsing article", article_id, e)
|
|
||||||
skip_articles.add(article_id)
|
|
||||||
|
|
||||||
# repeat checking this condition in case an exception was thrown
|
|
||||||
if current_doc and (current_article_id == article_id):
|
|
||||||
offset = "{}_{}".format(start, end)
|
|
||||||
found_ent = ents_by_offset.get(offset, None)
|
|
||||||
if found_ent:
|
|
||||||
if found_ent.text != alias:
|
|
||||||
skip_articles.add(article_id)
|
|
||||||
current_doc = None
|
|
||||||
else:
|
|
||||||
sent = found_ent.sent.as_doc()
|
|
||||||
|
|
||||||
gold_start = int(start) - found_ent.sent.start_char
|
|
||||||
gold_end = int(end) - found_ent.sent.start_char
|
|
||||||
|
|
||||||
gold_entities = {}
|
|
||||||
found_useful = False
|
|
||||||
for ent in sent.ents:
|
|
||||||
entry = (ent.start_char, ent.end_char)
|
|
||||||
gold_entry = (gold_start, gold_end)
|
|
||||||
if entry == gold_entry:
|
|
||||||
# add both pos and neg examples (in random order)
|
|
||||||
# this will exclude examples not in the KB
|
|
||||||
if kb:
|
|
||||||
value_by_id = {}
|
|
||||||
candidates = kb.get_candidates(alias)
|
|
||||||
candidate_ids = [
|
|
||||||
c.entity_ for c in candidates
|
|
||||||
]
|
|
||||||
random.shuffle(candidate_ids)
|
|
||||||
for kb_id in candidate_ids:
|
|
||||||
found_useful = True
|
|
||||||
if kb_id != wd_id:
|
|
||||||
value_by_id[kb_id] = 0.0
|
|
||||||
else:
|
|
||||||
value_by_id[kb_id] = 1.0
|
|
||||||
gold_entities[entry] = value_by_id
|
|
||||||
# if no KB, keep all positive examples
|
|
||||||
else:
|
|
||||||
found_useful = True
|
|
||||||
value_by_id = {wd_id: 1.0}
|
|
||||||
|
|
||||||
gold_entities[entry] = value_by_id
|
|
||||||
# currently feeding the gold data one entity per sentence at a time
|
|
||||||
# setting all other entities to empty gold dictionary
|
|
||||||
else:
|
|
||||||
gold_entities[entry] = {}
|
|
||||||
if found_useful:
|
|
||||||
gold = GoldParse(doc=sent, links=gold_entities)
|
|
||||||
data.append((sent, gold))
|
|
||||||
total_entities += 1
|
|
||||||
if len(data) % 2500 == 0:
|
|
||||||
print(" -read", total_entities, "entities")
|
|
||||||
|
|
||||||
print(" -read", total_entities, "entities")
|
|
||||||
return data
|
|
||||||
|
|
|
@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import datetime
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import plac
|
import plac
|
||||||
|
|
||||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
||||||
from bin.wiki_entity_linking import kb_creator
|
from bin.wiki_entity_linking import kb_creator
|
||||||
|
from bin.wiki_entity_linking import training_set_creator
|
||||||
|
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
||||||
|
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
|
||||||
import spacy
|
import spacy
|
||||||
|
|
||||||
from spacy import Errors
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def now():
|
|
||||||
return datetime.datetime.now()
|
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
|
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
|
||||||
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
|
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
|
||||||
output_dir=("Output directory", "positional", None, Path),
|
output_dir=("Output directory", "positional", None, Path),
|
||||||
model=("Model name, should include pretrained vectors.", "positional", None, str),
|
model=("Model name or path, should include pretrained vectors.", "positional", None, str),
|
||||||
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
|
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
|
||||||
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
|
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
|
||||||
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
|
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
|
||||||
|
@ -41,7 +39,9 @@ def now():
|
||||||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||||
|
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||||
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
|
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
|
||||||
|
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||||
)
|
)
|
||||||
def main(
|
def main(
|
||||||
wd_json,
|
wd_json,
|
||||||
|
@ -55,20 +55,29 @@ def main(
|
||||||
loc_prior_prob=None,
|
loc_prior_prob=None,
|
||||||
loc_entity_defs=None,
|
loc_entity_defs=None,
|
||||||
loc_entity_desc=None,
|
loc_entity_desc=None,
|
||||||
|
descriptions_from_wikipedia=False,
|
||||||
limit=None,
|
limit=None,
|
||||||
|
lang="en",
|
||||||
):
|
):
|
||||||
print(now(), "Creating KB with Wikipedia and WikiData")
|
|
||||||
print()
|
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
||||||
|
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
||||||
|
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
||||||
|
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
||||||
|
training_entities_path = output_dir / TRAINING_DATA_FILE
|
||||||
|
kb_path = output_dir / KB_FILE
|
||||||
|
|
||||||
|
logger.info("Creating KB with Wikipedia and WikiData")
|
||||||
|
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.")
|
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
|
||||||
|
|
||||||
# STEP 0: set up IO
|
# STEP 0: set up IO
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir()
|
output_dir.mkdir(parents=True)
|
||||||
|
|
||||||
# STEP 1: create the NLP object
|
# STEP 1: create the NLP object
|
||||||
print(now(), "STEP 1: loaded model", model)
|
logger.info("STEP 1: Loading model {}".format(model))
|
||||||
nlp = spacy.load(model)
|
nlp = spacy.load(model)
|
||||||
|
|
||||||
# check the length of the nlp vectors
|
# check the length of the nlp vectors
|
||||||
|
@ -79,64 +88,68 @@ def main(
|
||||||
)
|
)
|
||||||
|
|
||||||
# STEP 2: create prior probabilities from WP
|
# STEP 2: create prior probabilities from WP
|
||||||
print()
|
if not prior_prob_path.exists():
|
||||||
if loc_prior_prob:
|
|
||||||
print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob)
|
|
||||||
else:
|
|
||||||
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
||||||
loc_prior_prob = output_dir / "prior_prob.csv"
|
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
|
||||||
print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob)
|
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
|
||||||
wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit)
|
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
|
||||||
|
|
||||||
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
|
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
|
||||||
print()
|
logger.info("STEP 3: calculating entity frequencies")
|
||||||
print(now(), "STEP 3: calculating entity frequencies")
|
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
|
||||||
loc_entity_freq = output_dir / "entity_freq.csv"
|
|
||||||
wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False)
|
|
||||||
|
|
||||||
loc_kb = output_dir / "kb"
|
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
||||||
|
message = " and descriptions" if not descriptions_from_wikipedia else ""
|
||||||
# STEP 4: reading entity descriptions and definitions from WikiData or from file
|
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||||
print()
|
|
||||||
if loc_entity_defs and loc_entity_desc:
|
|
||||||
read_raw = False
|
|
||||||
print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs)
|
|
||||||
print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc)
|
|
||||||
else:
|
|
||||||
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
||||||
read_raw = True
|
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
|
||||||
loc_entity_defs = output_dir / "entity_defs.csv"
|
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
|
||||||
loc_entity_desc = output_dir / "entity_descriptions.csv"
|
wd_json,
|
||||||
print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions")
|
limit,
|
||||||
|
to_print=False,
|
||||||
|
lang=lang,
|
||||||
|
parse_descriptions=(not descriptions_from_wikipedia),
|
||||||
|
)
|
||||||
|
wd.write_entity_files(entity_defs_path, title_to_id)
|
||||||
|
if not descriptions_from_wikipedia:
|
||||||
|
wd.write_entity_description_files(entity_descr_path, id_to_descr)
|
||||||
|
logger.info("STEP 4: read entity definitions" + message)
|
||||||
|
|
||||||
# STEP 5: creating the actual KB
|
# STEP 5: Getting gold entities from wikipedia
|
||||||
|
message = " and descriptions" if descriptions_from_wikipedia else ""
|
||||||
|
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||||
|
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
|
||||||
|
training_set_creator.create_training_examples_and_descriptions(
|
||||||
|
wp_xml,
|
||||||
|
entity_defs_path,
|
||||||
|
entity_descr_path,
|
||||||
|
training_entities_path,
|
||||||
|
parse_descriptions=descriptions_from_wikipedia,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
logger.info("STEP 5: read gold entities" + message)
|
||||||
|
|
||||||
|
# STEP 6: creating the actual KB
|
||||||
# It takes ca. 30 minutes to pretrain the entity embeddings
|
# It takes ca. 30 minutes to pretrain the entity embeddings
|
||||||
print()
|
logger.info("STEP 6: creating the KB at {}".format(kb_path))
|
||||||
print(now(), "STEP 5: creating the KB at", loc_kb)
|
|
||||||
kb = kb_creator.create_kb(
|
kb = kb_creator.create_kb(
|
||||||
nlp=nlp,
|
nlp=nlp,
|
||||||
max_entities_per_alias=max_per_alias,
|
max_entities_per_alias=max_per_alias,
|
||||||
min_entity_freq=min_freq,
|
min_entity_freq=min_freq,
|
||||||
min_occ=min_pair,
|
min_occ=min_pair,
|
||||||
entity_def_output=loc_entity_defs,
|
entity_def_input=entity_defs_path,
|
||||||
entity_descr_output=loc_entity_desc,
|
entity_descr_path=entity_descr_path,
|
||||||
count_input=loc_entity_freq,
|
count_input=entity_freq_path,
|
||||||
prior_prob_input=loc_prior_prob,
|
prior_prob_input=prior_prob_path,
|
||||||
wikidata_input=wd_json,
|
|
||||||
entity_vector_length=entity_vector_length,
|
entity_vector_length=entity_vector_length,
|
||||||
limit=limit,
|
|
||||||
read_raw_data=read_raw,
|
|
||||||
)
|
)
|
||||||
if read_raw:
|
|
||||||
print(" - wrote entity definitions to", loc_entity_defs)
|
|
||||||
print(" - wrote writing entity descriptions to", loc_entity_desc)
|
|
||||||
|
|
||||||
kb.dump(loc_kb)
|
kb.dump(kb_path)
|
||||||
nlp.to_disk(output_dir / "nlp")
|
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||||
|
|
||||||
print()
|
logger.info("Done!")
|
||||||
print(now(), "Done!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import bz2
|
import gzip
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
|
|
||||||
|
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
|
||||||
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
||||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||||
|
|
||||||
lang = "en"
|
site_filter = '{}wiki'.format(lang)
|
||||||
site_filter = "enwiki"
|
|
||||||
|
|
||||||
# properties filter (currently disabled to get ALL data)
|
# properties filter (currently disabled to get ALL data)
|
||||||
prop_filter = dict()
|
prop_filter = dict()
|
||||||
|
@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
|
||||||
parse_properties = False
|
parse_properties = False
|
||||||
parse_sitelinks = True
|
parse_sitelinks = True
|
||||||
parse_labels = False
|
parse_labels = False
|
||||||
parse_descriptions = True
|
|
||||||
parse_aliases = False
|
parse_aliases = False
|
||||||
parse_claims = False
|
parse_claims = False
|
||||||
|
|
||||||
with bz2.open(wikidata_file, mode="rb") as file:
|
with gzip.open(wikidata_file, mode='rb') as file:
|
||||||
line = file.readline()
|
for cnt, line in enumerate(file):
|
||||||
cnt = 0
|
if limit and cnt >= limit:
|
||||||
while line and (not limit or cnt < limit):
|
break
|
||||||
if cnt % 1000000 == 0:
|
if cnt % 500000 == 0:
|
||||||
print(
|
logger.info("processed {} lines of WikiData dump".format(cnt))
|
||||||
datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump"
|
|
||||||
)
|
|
||||||
clean_line = line.strip()
|
clean_line = line.strip()
|
||||||
if clean_line.endswith(b","):
|
if clean_line.endswith(b","):
|
||||||
clean_line = clean_line[:-1]
|
clean_line = clean_line[:-1]
|
||||||
|
@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
|
||||||
|
|
||||||
if to_print:
|
if to_print:
|
||||||
print()
|
print()
|
||||||
line = file.readline()
|
|
||||||
cnt += 1
|
|
||||||
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump")
|
|
||||||
|
|
||||||
return title_to_id, id_to_descr
|
return title_to_id, id_to_descr
|
||||||
|
|
||||||
|
|
||||||
|
def write_entity_files(entity_def_output, title_to_id):
|
||||||
|
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||||
|
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||||
|
for title, qid in title_to_id.items():
|
||||||
|
id_file.write(title + "|" + str(qid) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def write_entity_description_files(entity_descr_output, id_to_descr):
|
||||||
|
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||||
|
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||||
|
for qid, descr in id_to_descr.items():
|
||||||
|
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||||
|
|
|
@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import datetime
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import plac
|
import plac
|
||||||
|
|
||||||
from bin.wiki_entity_linking import training_set_creator
|
from bin.wiki_entity_linking import training_set_creator
|
||||||
|
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||||
|
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
|
||||||
|
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
|
||||||
|
|
||||||
import spacy
|
|
||||||
from spacy.kb import KnowledgeBase
|
|
||||||
from spacy.util import minibatch, compounding
|
from spacy.util import minibatch, compounding
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
def now():
|
|
||||||
return datetime.datetime.now()
|
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
|
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
|
||||||
output_dir=("Output directory", "option", "o", Path),
|
output_dir=("Output directory", "option", "o", Path),
|
||||||
loc_training=("Location to training data", "option", "k", Path),
|
loc_training=("Location to training data", "option", "k", Path),
|
||||||
wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
|
|
||||||
epochs=("Number of training iterations (default 10)", "option", "e", int),
|
epochs=("Number of training iterations (default 10)", "option", "e", int),
|
||||||
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
||||||
lr=("Learning rate (default 0.005)", "option", "n", float),
|
lr=("Learning rate (default 0.005)", "option", "n", float),
|
||||||
l2=("L2 regularization", "option", "r", float),
|
l2=("L2 regularization", "option", "r", float),
|
||||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||||
limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
|
|
||||||
)
|
)
|
||||||
def main(
|
def main(
|
||||||
dir_kb,
|
dir_kb,
|
||||||
output_dir=None,
|
output_dir=None,
|
||||||
loc_training=None,
|
loc_training=None,
|
||||||
wp_xml=None,
|
|
||||||
epochs=10,
|
epochs=10,
|
||||||
dropout=0.5,
|
dropout=0.5,
|
||||||
lr=0.005,
|
lr=0.005,
|
||||||
l2=1e-6,
|
l2=1e-6,
|
||||||
train_inst=None,
|
train_inst=None,
|
||||||
dev_inst=None,
|
dev_inst=None,
|
||||||
limit=None,
|
|
||||||
):
|
):
|
||||||
print(now(), "Creating Entity Linker with Wikipedia and WikiData")
|
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||||
print()
|
|
||||||
|
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||||
|
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
|
||||||
|
nlp_dir = dir_kb / KB_MODEL_DIR
|
||||||
|
kb_path = output_dir / KB_FILE
|
||||||
|
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
||||||
|
|
||||||
# STEP 0: set up IO
|
# STEP 0: set up IO
|
||||||
if output_dir and not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
|
|
||||||
# STEP 1 : load the NLP object
|
# STEP 1 : load the NLP object
|
||||||
nlp_dir = dir_kb / "nlp"
|
logger.info("STEP 1: loading model from {}".format(nlp_dir))
|
||||||
print(now(), "STEP 1: loading model from", nlp_dir)
|
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
|
||||||
nlp = spacy.load(nlp_dir)
|
|
||||||
|
|
||||||
# check that there is a NER component in the pipeline
|
# check that there is a NER component in the pipeline
|
||||||
if "ner" not in nlp.pipe_names:
|
if "ner" not in nlp.pipe_names:
|
||||||
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
|
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
|
||||||
|
|
||||||
# STEP 2 : read the KB
|
# STEP 2: create a training dataset from WP
|
||||||
print()
|
logger.info("STEP 2: reading training dataset from {}".format(training_path))
|
||||||
print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
|
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
|
||||||
kb.load_bulk(dir_kb / "kb")
|
|
||||||
|
|
||||||
# STEP 3: create a training dataset from WP
|
|
||||||
print()
|
|
||||||
if loc_training:
|
|
||||||
print(now(), "STEP 3: reading training dataset from", loc_training)
|
|
||||||
else:
|
|
||||||
if not wp_xml:
|
|
||||||
raise ValueError(
|
|
||||||
"Either provide a path to a preprocessed training directory, "
|
|
||||||
"or to the original Wikipedia XML dump."
|
|
||||||
)
|
|
||||||
|
|
||||||
if output_dir:
|
|
||||||
loc_training = output_dir / "training_data"
|
|
||||||
else:
|
|
||||||
loc_training = dir_kb / "training_data"
|
|
||||||
if not loc_training.exists():
|
|
||||||
loc_training.mkdir()
|
|
||||||
print(now(), "STEP 3: creating training dataset at", loc_training)
|
|
||||||
|
|
||||||
if limit is not None:
|
|
||||||
print("Warning: reading only", limit, "lines of Wikipedia dump.")
|
|
||||||
|
|
||||||
loc_entity_defs = dir_kb / "entity_defs.csv"
|
|
||||||
training_set_creator.create_training(
|
|
||||||
wikipedia_input=wp_xml,
|
|
||||||
entity_def_input=loc_entity_defs,
|
|
||||||
training_output=loc_training,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
# STEP 4: parse the training data
|
|
||||||
print()
|
|
||||||
print(now(), "STEP 4: parse the training & evaluation data")
|
|
||||||
|
|
||||||
# for training, get pos & neg instances that correspond to entries in the kb
|
|
||||||
print("Parsing training data, limit =", train_inst)
|
|
||||||
train_data = training_set_creator.read_training(
|
train_data = training_set_creator.read_training(
|
||||||
nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
|
nlp=nlp,
|
||||||
|
entity_file_path=training_path,
|
||||||
|
dev=False,
|
||||||
|
limit=train_inst,
|
||||||
|
kb=kb,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Training on", len(train_data), "articles")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print("Parsing dev testing data, limit =", dev_inst)
|
|
||||||
# for testing, get all pos instances, whether or not they are in the kb
|
# for testing, get all pos instances, whether or not they are in the kb
|
||||||
dev_data = training_set_creator.read_training(
|
dev_data = training_set_creator.read_training(
|
||||||
nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
|
nlp=nlp,
|
||||||
|
entity_file_path=training_path,
|
||||||
|
dev=True,
|
||||||
|
limit=dev_inst,
|
||||||
|
kb=kb,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Dev testing on", len(dev_data), "articles")
|
# STEP 3: create and train the entity linking pipe
|
||||||
print()
|
logger.info("STEP 3: training Entity Linking pipe")
|
||||||
|
|
||||||
# STEP 5: create and train the entity linking pipe
|
|
||||||
print()
|
|
||||||
print(now(), "STEP 5: training Entity Linking pipe")
|
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(
|
el_pipe = nlp.create_pipe(
|
||||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
|
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
|
||||||
|
@ -142,275 +102,70 @@ def main(
|
||||||
optimizer.learn_rate = lr
|
optimizer.learn_rate = lr
|
||||||
optimizer.L2 = l2
|
optimizer.L2 = l2
|
||||||
|
|
||||||
if not train_data:
|
logger.info("Training on {} articles".format(len(train_data)))
|
||||||
print("Did not find any training data")
|
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||||
else:
|
|
||||||
for itn in range(epochs):
|
|
||||||
random.shuffle(train_data)
|
|
||||||
losses = {}
|
|
||||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
|
||||||
batchnr = 0
|
|
||||||
|
|
||||||
with nlp.disable_pipes(*other_pipes):
|
dev_baseline_accuracies = measure_baselines(
|
||||||
for batch in batches:
|
dev_data, kb
|
||||||
try:
|
|
||||||
docs, golds = zip(*batch)
|
|
||||||
nlp.update(
|
|
||||||
docs=docs,
|
|
||||||
golds=golds,
|
|
||||||
sgd=optimizer,
|
|
||||||
drop=dropout,
|
|
||||||
losses=losses,
|
|
||||||
)
|
|
||||||
batchnr += 1
|
|
||||||
except Exception as e:
|
|
||||||
print("Error updating batch:", e)
|
|
||||||
|
|
||||||
if batchnr > 0:
|
|
||||||
el_pipe.cfg["incl_context"] = True
|
|
||||||
el_pipe.cfg["incl_prior"] = True
|
|
||||||
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
|
|
||||||
losses["entity_linker"] = losses["entity_linker"] / batchnr
|
|
||||||
print(
|
|
||||||
"Epoch, train loss",
|
|
||||||
itn,
|
|
||||||
round(losses["entity_linker"], 2),
|
|
||||||
" / dev accuracy avg",
|
|
||||||
round(dev_acc_context, 3),
|
|
||||||
)
|
|
||||||
|
|
||||||
# STEP 6: measure the performance of our trained pipe on an independent dev set
|
|
||||||
print()
|
|
||||||
if len(dev_data):
|
|
||||||
print()
|
|
||||||
print(now(), "STEP 6: performance measurement of Entity Linking pipe")
|
|
||||||
print()
|
|
||||||
|
|
||||||
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
|
|
||||||
dev_data, kb
|
|
||||||
)
|
|
||||||
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
|
||||||
|
|
||||||
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
|
|
||||||
print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)
|
|
||||||
|
|
||||||
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
|
|
||||||
print("dev accuracy random:", round(acc_r, 3), random_by_label)
|
|
||||||
|
|
||||||
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
|
|
||||||
print("dev accuracy prior:", round(acc_p, 3), prior_by_label)
|
|
||||||
|
|
||||||
# using only context
|
|
||||||
el_pipe.cfg["incl_context"] = True
|
|
||||||
el_pipe.cfg["incl_prior"] = False
|
|
||||||
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
|
|
||||||
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
|
|
||||||
print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)
|
|
||||||
|
|
||||||
# measuring combined accuracy (prior + context)
|
|
||||||
el_pipe.cfg["incl_context"] = True
|
|
||||||
el_pipe.cfg["incl_prior"] = True
|
|
||||||
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
|
|
||||||
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
|
|
||||||
print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)
|
|
||||||
|
|
||||||
# STEP 7: apply the EL pipe on a toy example
|
|
||||||
print()
|
|
||||||
print(now(), "STEP 7: applying Entity Linking to toy example")
|
|
||||||
print()
|
|
||||||
run_el_toy_example(nlp=nlp)
|
|
||||||
|
|
||||||
# STEP 8: write the NLP pipeline (including entity linker) to file
|
|
||||||
if output_dir:
|
|
||||||
print()
|
|
||||||
nlp_loc = output_dir / "nlp"
|
|
||||||
print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
|
|
||||||
nlp.to_disk(nlp_loc)
|
|
||||||
print()
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(now(), "Done!")
|
|
||||||
|
|
||||||
|
|
||||||
def _measure_acc(data, el_pipe=None, error_analysis=False):
|
|
||||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
|
||||||
correct_by_label = dict()
|
|
||||||
incorrect_by_label = dict()
|
|
||||||
|
|
||||||
docs = [d for d, g in data if len(d) > 0]
|
|
||||||
if el_pipe is not None:
|
|
||||||
docs = list(el_pipe.pipe(docs))
|
|
||||||
golds = [g for d, g in data if len(d) > 0]
|
|
||||||
|
|
||||||
for doc, gold in zip(docs, golds):
|
|
||||||
try:
|
|
||||||
correct_entries_per_article = dict()
|
|
||||||
for entity, kb_dict in gold.links.items():
|
|
||||||
start, end = entity
|
|
||||||
# only evaluating on positive examples
|
|
||||||
for gold_kb, value in kb_dict.items():
|
|
||||||
if value:
|
|
||||||
offset = _offset(start, end)
|
|
||||||
correct_entries_per_article[offset] = gold_kb
|
|
||||||
|
|
||||||
for ent in doc.ents:
|
|
||||||
ent_label = ent.label_
|
|
||||||
pred_entity = ent.kb_id_
|
|
||||||
start = ent.start_char
|
|
||||||
end = ent.end_char
|
|
||||||
offset = _offset(start, end)
|
|
||||||
gold_entity = correct_entries_per_article.get(offset, None)
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
|
||||||
if gold_entity is not None:
|
|
||||||
if gold_entity == pred_entity:
|
|
||||||
correct = correct_by_label.get(ent_label, 0)
|
|
||||||
correct_by_label[ent_label] = correct + 1
|
|
||||||
else:
|
|
||||||
incorrect = incorrect_by_label.get(ent_label, 0)
|
|
||||||
incorrect_by_label[ent_label] = incorrect + 1
|
|
||||||
if error_analysis:
|
|
||||||
print(ent.text, "in", doc)
|
|
||||||
print(
|
|
||||||
"Predicted",
|
|
||||||
pred_entity,
|
|
||||||
"should have been",
|
|
||||||
gold_entity,
|
|
||||||
)
|
|
||||||
print()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print("Error assessing accuracy", e)
|
|
||||||
|
|
||||||
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
|
|
||||||
return acc, acc_by_label
|
|
||||||
|
|
||||||
|
|
||||||
def _measure_baselines(data, kb):
|
|
||||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
|
||||||
counts_d = dict()
|
|
||||||
|
|
||||||
random_correct_d = dict()
|
|
||||||
random_incorrect_d = dict()
|
|
||||||
|
|
||||||
oracle_correct_d = dict()
|
|
||||||
oracle_incorrect_d = dict()
|
|
||||||
|
|
||||||
prior_correct_d = dict()
|
|
||||||
prior_incorrect_d = dict()
|
|
||||||
|
|
||||||
docs = [d for d, g in data if len(d) > 0]
|
|
||||||
golds = [g for d, g in data if len(d) > 0]
|
|
||||||
|
|
||||||
for doc, gold in zip(docs, golds):
|
|
||||||
try:
|
|
||||||
correct_entries_per_article = dict()
|
|
||||||
for entity, kb_dict in gold.links.items():
|
|
||||||
start, end = entity
|
|
||||||
for gold_kb, value in kb_dict.items():
|
|
||||||
# only evaluating on positive examples
|
|
||||||
if value:
|
|
||||||
offset = _offset(start, end)
|
|
||||||
correct_entries_per_article[offset] = gold_kb
|
|
||||||
|
|
||||||
for ent in doc.ents:
|
|
||||||
label = ent.label_
|
|
||||||
start = ent.start_char
|
|
||||||
end = ent.end_char
|
|
||||||
offset = _offset(start, end)
|
|
||||||
gold_entity = correct_entries_per_article.get(offset, None)
|
|
||||||
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
|
||||||
if gold_entity is not None:
|
|
||||||
counts_d[label] = counts_d.get(label, 0) + 1
|
|
||||||
candidates = kb.get_candidates(ent.text)
|
|
||||||
oracle_candidate = ""
|
|
||||||
best_candidate = ""
|
|
||||||
random_candidate = ""
|
|
||||||
if candidates:
|
|
||||||
scores = []
|
|
||||||
|
|
||||||
for c in candidates:
|
|
||||||
scores.append(c.prior_prob)
|
|
||||||
if c.entity_ == gold_entity:
|
|
||||||
oracle_candidate = c.entity_
|
|
||||||
|
|
||||||
best_index = scores.index(max(scores))
|
|
||||||
best_candidate = candidates[best_index].entity_
|
|
||||||
random_candidate = random.choice(candidates).entity_
|
|
||||||
|
|
||||||
if gold_entity == best_candidate:
|
|
||||||
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
|
|
||||||
else:
|
|
||||||
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
|
|
||||||
|
|
||||||
if gold_entity == random_candidate:
|
|
||||||
random_correct_d[label] = random_correct_d.get(label, 0) + 1
|
|
||||||
else:
|
|
||||||
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
|
|
||||||
|
|
||||||
if gold_entity == oracle_candidate:
|
|
||||||
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
|
|
||||||
else:
|
|
||||||
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print("Error assessing accuracy", e)
|
|
||||||
|
|
||||||
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
|
|
||||||
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
|
|
||||||
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
|
|
||||||
|
|
||||||
return (
|
|
||||||
counts_d,
|
|
||||||
acc_rand,
|
|
||||||
acc_rand_d,
|
|
||||||
acc_prior,
|
|
||||||
acc_prior_d,
|
|
||||||
acc_oracle,
|
|
||||||
acc_oracle_d,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("Dev Baseline Accuracies:")
|
||||||
|
logger.info(dev_baseline_accuracies.report_accuracy("random"))
|
||||||
|
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
|
||||||
|
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
|
||||||
|
|
||||||
def _offset(start, end):
|
for itn in range(epochs):
|
||||||
return "{}_{}".format(start, end)
|
random.shuffle(train_data)
|
||||||
|
losses = {}
|
||||||
|
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||||
|
batchnr = 0
|
||||||
|
|
||||||
|
with nlp.disable_pipes(*other_pipes):
|
||||||
|
for batch in batches:
|
||||||
|
try:
|
||||||
|
docs, golds = zip(*batch)
|
||||||
|
nlp.update(
|
||||||
|
docs=docs,
|
||||||
|
golds=golds,
|
||||||
|
sgd=optimizer,
|
||||||
|
drop=dropout,
|
||||||
|
losses=losses,
|
||||||
|
)
|
||||||
|
batchnr += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error updating batch:" + str(e))
|
||||||
|
if batchnr > 0:
|
||||||
|
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||||
|
measure_performance(dev_data, kb, el_pipe)
|
||||||
|
|
||||||
def calculate_acc(correct_by_label, incorrect_by_label):
|
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||||
acc_by_label = dict()
|
logger.info("STEP 4: performance measurement of Entity Linking pipe")
|
||||||
total_correct = 0
|
measure_performance(dev_data, kb, el_pipe)
|
||||||
total_incorrect = 0
|
|
||||||
all_keys = set()
|
# STEP 5: apply the EL pipe on a toy example
|
||||||
all_keys.update(correct_by_label.keys())
|
logger.info("STEP 5: applying Entity Linking to toy example")
|
||||||
all_keys.update(incorrect_by_label.keys())
|
run_el_toy_example(nlp=nlp)
|
||||||
for label in sorted(all_keys):
|
|
||||||
correct = correct_by_label.get(label, 0)
|
if output_dir:
|
||||||
incorrect = incorrect_by_label.get(label, 0)
|
# STEP 6: write the NLP pipeline (including entity linker) to file
|
||||||
total_correct += correct
|
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||||
total_incorrect += incorrect
|
nlp.to_disk(nlp_output_dir)
|
||||||
if correct == incorrect == 0:
|
|
||||||
acc_by_label[label] = 0
|
logger.info("Done!")
|
||||||
else:
|
|
||||||
acc_by_label[label] = correct / (correct + incorrect)
|
|
||||||
acc = 0
|
|
||||||
if not (total_correct == total_incorrect == 0):
|
|
||||||
acc = total_correct / (total_correct + total_incorrect)
|
|
||||||
return acc, acc_by_label
|
|
||||||
|
|
||||||
|
|
||||||
def check_kb(kb):
|
def check_kb(kb):
|
||||||
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||||
candidates = kb.get_candidates(mention)
|
candidates = kb.get_candidates(mention)
|
||||||
|
|
||||||
print("generating candidates for " + mention + " :")
|
logger.info("generating candidates for " + mention + " :")
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
print(
|
logger.info(" ".join[
|
||||||
" ",
|
str(c.prior_prob),
|
||||||
c.prior_prob,
|
|
||||||
c.alias_,
|
c.alias_,
|
||||||
"-->",
|
"-->",
|
||||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
|
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
|
||||||
)
|
])
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def run_el_toy_example(nlp):
|
def run_el_toy_example(nlp):
|
||||||
|
@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
|
||||||
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
||||||
)
|
)
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
print(text)
|
logger.info(text)
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
print(" ent", ent.text, ent.label_, ent.kb_id_)
|
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
|
@ -5,6 +5,9 @@ import re
|
||||||
import bz2
|
import bz2
|
||||||
import csv
|
import csv
|
||||||
import datetime
|
import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from bin.wiki_entity_linking import LOG_FORMAT
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||||
|
@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation.
|
||||||
|
|
||||||
map_alias_to_link = dict()
|
map_alias_to_link = dict()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# these will/should be matched ignoring case
|
# these will/should be matched ignoring case
|
||||||
wiki_namespaces = [
|
wiki_namespaces = [
|
||||||
"b",
|
"b",
|
||||||
|
@ -116,10 +122,6 @@ for ns in wiki_namespaces:
|
||||||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
def now():
|
|
||||||
return datetime.datetime.now()
|
|
||||||
|
|
||||||
|
|
||||||
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||||
"""
|
"""
|
||||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||||
|
@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||||
cnt = 0
|
cnt = 0
|
||||||
while line and (not limit or cnt < limit):
|
while line and (not limit or cnt < limit):
|
||||||
if cnt % 25000000 == 0:
|
if cnt % 25000000 == 0:
|
||||||
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
|
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||||
clean_line = line.strip().decode("utf-8")
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||||
|
@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||||
|
|
||||||
line = file.readline()
|
line = file.readline()
|
||||||
cnt += 1
|
cnt += 1
|
||||||
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
|
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||||
|
|
||||||
# write all aliases and their entities and count occurrences to file
|
# write all aliases and their entities and count occurrences to file
|
||||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||||
|
|
|
@ -455,6 +455,8 @@ class Errors(object):
|
||||||
E158 = ("Can't add table '{name}' to lookups because it already exists.")
|
E158 = ("Can't add table '{name}' to lookups because it already exists.")
|
||||||
E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}")
|
E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}")
|
||||||
E160 = ("Can't find language data file: {path}")
|
E160 = ("Can't find language data file: {path}")
|
||||||
|
E161 = ("Found an internal inconsistency when predicting entity links. "
|
||||||
|
"This is likely a bug in spaCy, so feel free to open an issue.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -1287,7 +1287,7 @@ class EntityLinker(Pipe):
|
||||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||||
if not self.cfg.get("incl_prior", True):
|
if not self.cfg.get("incl_prior", True):
|
||||||
prior_probs = xp.asarray([[0.0] for c in candidates])
|
prior_probs = xp.asarray([0.0 for c in candidates])
|
||||||
scores = prior_probs
|
scores = prior_probs
|
||||||
|
|
||||||
# add in similarity from the context
|
# add in similarity from the context
|
||||||
|
@ -1300,6 +1300,8 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
# cosine similarity
|
# cosine similarity
|
||||||
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
|
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
|
||||||
|
if sims.shape != prior_probs.shape:
|
||||||
|
raise ValueError(Errors.E161)
|
||||||
scores = prior_probs + sims - (prior_probs*sims)
|
scores = prior_probs + sims - (prior_probs*sims)
|
||||||
|
|
||||||
# TODO: thresholding
|
# TODO: thresholding
|
||||||
|
|
Loading…
Reference in New Issue
Block a user