mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 13:17:06 +03:00
Update NEL examples and documentation (#5370)
* simplify creation of KB by skipping dim reduction * small fixes to train EL example script * add KB creation and NEL training example scripts to example section * update descriptions of example scripts in the documentation * moving wiki_entity_linking folder from bin to projects * remove test for wiki NEL functionality that is being moved # Conflicts: # bin/wiki_entity_linking/wikipedia_processor.py
This commit is contained in:
parent
51207c9417
commit
cafe94ee04
|
@ -1,37 +0,0 @@
|
||||||
## Entity Linking with Wikipedia and Wikidata
|
|
||||||
|
|
||||||
### Step 1: Create a Knowledge Base (KB) and training data
|
|
||||||
|
|
||||||
Run `wikidata_pretrain_kb.py`
|
|
||||||
* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file**
|
|
||||||
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
|
|
||||||
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
|
|
||||||
* You can set the filtering parameters for KB construction:
|
|
||||||
* `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
|
|
||||||
* `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
|
|
||||||
* `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
|
||||||
* Further parameters to set:
|
|
||||||
* `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
|
||||||
* `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
|
|
||||||
* `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)
|
|
||||||
|
|
||||||
Quick testing and rerunning:
|
|
||||||
* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
|
|
||||||
* e.g. set `-lt 20000 -lp 2000 -lw 3000 -f 1`
|
|
||||||
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
|
|
||||||
|
|
||||||
|
|
||||||
### Step 2: Train an Entity Linking model
|
|
||||||
|
|
||||||
Run `wikidata_train_entity_linker.py`
|
|
||||||
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
|
|
||||||
* Specify the output directory (`-o`) in which the final, trained model will be saved
|
|
||||||
* You can set the learning parameters for the EL training:
|
|
||||||
* `epochs` (`-e`): number of training iterations
|
|
||||||
* `dropout` (`-p`): dropout rate
|
|
||||||
* `lr` (`-n`): learning rate
|
|
||||||
* `l2` (`-r`): L2 regularization
|
|
||||||
* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
|
|
||||||
* If not specified, the full dataset will be processed - this may take a LONG time !
|
|
||||||
* Further parameters to set:
|
|
||||||
* `labels_discard` (`-l`): NER label types to discard during training
|
|
|
@ -1,12 +0,0 @@
|
||||||
TRAINING_DATA_FILE = "gold_entities.jsonl"
|
|
||||||
KB_FILE = "kb"
|
|
||||||
KB_MODEL_DIR = "nlp_kb"
|
|
||||||
OUTPUT_MODEL_DIR = "nlp"
|
|
||||||
|
|
||||||
PRIOR_PROB_PATH = "prior_prob.csv"
|
|
||||||
ENTITY_DEFS_PATH = "entity_defs.csv"
|
|
||||||
ENTITY_FREQ_PATH = "entity_freq.csv"
|
|
||||||
ENTITY_ALIAS_PATH = "entity_alias.csv"
|
|
||||||
ENTITY_DESCR_PATH = "entity_descriptions.csv"
|
|
||||||
|
|
||||||
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
|
|
@ -1,204 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Metrics(object):
|
|
||||||
true_pos = 0
|
|
||||||
false_pos = 0
|
|
||||||
false_neg = 0
|
|
||||||
|
|
||||||
def update_results(self, true_entity, candidate):
|
|
||||||
candidate_is_correct = true_entity == candidate
|
|
||||||
|
|
||||||
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
|
|
||||||
# Therefore, if candidate_is_correct then we have a true positive and never a true negative.
|
|
||||||
self.true_pos += candidate_is_correct
|
|
||||||
self.false_neg += not candidate_is_correct
|
|
||||||
if candidate and candidate not in {"", "NIL"}:
|
|
||||||
# A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
|
|
||||||
self.false_pos += not candidate_is_correct
|
|
||||||
|
|
||||||
def calculate_precision(self):
|
|
||||||
if self.true_pos == 0:
|
|
||||||
return 0.0
|
|
||||||
else:
|
|
||||||
return self.true_pos / (self.true_pos + self.false_pos)
|
|
||||||
|
|
||||||
def calculate_recall(self):
|
|
||||||
if self.true_pos == 0:
|
|
||||||
return 0.0
|
|
||||||
else:
|
|
||||||
return self.true_pos / (self.true_pos + self.false_neg)
|
|
||||||
|
|
||||||
def calculate_fscore(self):
|
|
||||||
p = self.calculate_precision()
|
|
||||||
r = self.calculate_recall()
|
|
||||||
if p + r == 0:
|
|
||||||
return 0.0
|
|
||||||
else:
|
|
||||||
return 2 * p * r / (p + r)
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluationResults(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.metrics = Metrics()
|
|
||||||
self.metrics_by_label = defaultdict(Metrics)
|
|
||||||
|
|
||||||
def update_metrics(self, ent_label, true_entity, candidate):
|
|
||||||
self.metrics.update_results(true_entity, candidate)
|
|
||||||
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
|
|
||||||
|
|
||||||
def report_metrics(self, model_name):
|
|
||||||
model_str = model_name.title()
|
|
||||||
recall = self.metrics.calculate_recall()
|
|
||||||
precision = self.metrics.calculate_precision()
|
|
||||||
fscore = self.metrics.calculate_fscore()
|
|
||||||
return (
|
|
||||||
"{}: ".format(model_str)
|
|
||||||
+ "F-score = {} | ".format(round(fscore, 3))
|
|
||||||
+ "Recall = {} | ".format(round(recall, 3))
|
|
||||||
+ "Precision = {} | ".format(round(precision, 3))
|
|
||||||
+ "F-score by label = {}".format(
|
|
||||||
{k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaselineResults(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.random = EvaluationResults()
|
|
||||||
self.prior = EvaluationResults()
|
|
||||||
self.oracle = EvaluationResults()
|
|
||||||
|
|
||||||
def report_performance(self, model):
|
|
||||||
results = getattr(self, model)
|
|
||||||
return results.report_metrics(model)
|
|
||||||
|
|
||||||
def update_baselines(
|
|
||||||
self,
|
|
||||||
true_entity,
|
|
||||||
ent_label,
|
|
||||||
random_candidate,
|
|
||||||
prior_candidate,
|
|
||||||
oracle_candidate,
|
|
||||||
):
|
|
||||||
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
|
|
||||||
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
|
|
||||||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
|
||||||
|
|
||||||
|
|
||||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
|
|
||||||
counts = dict()
|
|
||||||
baseline_results = BaselineResults()
|
|
||||||
context_results = EvaluationResults()
|
|
||||||
combo_results = EvaluationResults()
|
|
||||||
|
|
||||||
for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
|
|
||||||
if len(doc) > 0:
|
|
||||||
correct_ents = dict()
|
|
||||||
for entity, kb_dict in gold.links.items():
|
|
||||||
start, end = entity
|
|
||||||
for gold_kb, value in kb_dict.items():
|
|
||||||
if value:
|
|
||||||
# only evaluating on positive examples
|
|
||||||
offset = _offset(start, end)
|
|
||||||
correct_ents[offset] = gold_kb
|
|
||||||
|
|
||||||
if baseline:
|
|
||||||
_add_baseline(baseline_results, counts, doc, correct_ents, kb)
|
|
||||||
|
|
||||||
if context:
|
|
||||||
# using only context
|
|
||||||
el_pipe.cfg["incl_context"] = True
|
|
||||||
el_pipe.cfg["incl_prior"] = False
|
|
||||||
_add_eval_result(context_results, doc, correct_ents, el_pipe)
|
|
||||||
|
|
||||||
# measuring combined accuracy (prior + context)
|
|
||||||
el_pipe.cfg["incl_context"] = True
|
|
||||||
el_pipe.cfg["incl_prior"] = True
|
|
||||||
_add_eval_result(combo_results, doc, correct_ents, el_pipe)
|
|
||||||
|
|
||||||
if baseline:
|
|
||||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
|
||||||
logger.info(baseline_results.report_performance("random"))
|
|
||||||
logger.info(baseline_results.report_performance("prior"))
|
|
||||||
logger.info(baseline_results.report_performance("oracle"))
|
|
||||||
|
|
||||||
if context:
|
|
||||||
logger.info(context_results.report_metrics("context only"))
|
|
||||||
logger.info(combo_results.report_metrics("context and prior"))
|
|
||||||
|
|
||||||
|
|
||||||
def _add_eval_result(results, doc, correct_ents, el_pipe):
|
|
||||||
"""
|
|
||||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
|
||||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc = el_pipe(doc)
|
|
||||||
for ent in doc.ents:
|
|
||||||
ent_label = ent.label_
|
|
||||||
start = ent.start_char
|
|
||||||
end = ent.end_char
|
|
||||||
offset = _offset(start, end)
|
|
||||||
gold_entity = correct_ents.get(offset, None)
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
|
||||||
if gold_entity is not None:
|
|
||||||
pred_entity = ent.kb_id_
|
|
||||||
results.update_metrics(ent_label, gold_entity, pred_entity)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error assessing accuracy " + str(e))
|
|
||||||
|
|
||||||
|
|
||||||
def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
|
|
||||||
"""
|
|
||||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
|
||||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
|
||||||
"""
|
|
||||||
for ent in doc.ents:
|
|
||||||
ent_label = ent.label_
|
|
||||||
start = ent.start_char
|
|
||||||
end = ent.end_char
|
|
||||||
offset = _offset(start, end)
|
|
||||||
gold_entity = correct_ents.get(offset, None)
|
|
||||||
|
|
||||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
|
||||||
if gold_entity is not None:
|
|
||||||
candidates = kb.get_candidates(ent.text)
|
|
||||||
oracle_candidate = ""
|
|
||||||
prior_candidate = ""
|
|
||||||
random_candidate = ""
|
|
||||||
if candidates:
|
|
||||||
scores = []
|
|
||||||
|
|
||||||
for c in candidates:
|
|
||||||
scores.append(c.prior_prob)
|
|
||||||
if c.entity_ == gold_entity:
|
|
||||||
oracle_candidate = c.entity_
|
|
||||||
|
|
||||||
best_index = scores.index(max(scores))
|
|
||||||
prior_candidate = candidates[best_index].entity_
|
|
||||||
random_candidate = random.choice(candidates).entity_
|
|
||||||
|
|
||||||
current_count = counts.get(ent_label, 0)
|
|
||||||
counts[ent_label] = current_count+1
|
|
||||||
|
|
||||||
baseline_results.update_baselines(
|
|
||||||
gold_entity,
|
|
||||||
ent_label,
|
|
||||||
random_candidate,
|
|
||||||
prior_candidate,
|
|
||||||
oracle_candidate,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _offset(start, end):
|
|
||||||
return "{}_{}".format(start, end)
|
|
|
@ -1,161 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
|
||||||
|
|
||||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
|
||||||
from bin.wiki_entity_linking import wiki_io as io
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_kb(
|
|
||||||
nlp,
|
|
||||||
max_entities_per_alias,
|
|
||||||
min_entity_freq,
|
|
||||||
min_occ,
|
|
||||||
entity_def_path,
|
|
||||||
entity_descr_path,
|
|
||||||
entity_alias_path,
|
|
||||||
entity_freq_path,
|
|
||||||
prior_prob_path,
|
|
||||||
entity_vector_length,
|
|
||||||
):
|
|
||||||
# Create the knowledge base from Wikidata entries
|
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
|
||||||
entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
|
|
||||||
_define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
|
|
||||||
return kb
|
|
||||||
|
|
||||||
|
|
||||||
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
|
|
||||||
# read the mappings from file
|
|
||||||
title_to_id = io.read_title_to_id(entity_def_path)
|
|
||||||
id_to_descr = io.read_id_to_descr(entity_descr_path)
|
|
||||||
|
|
||||||
# check the length of the nlp vectors
|
|
||||||
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
|
||||||
input_dim = nlp.vocab.vectors_length
|
|
||||||
logger.info("Loaded pretrained vectors of size %s" % input_dim)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The `nlp` object should have access to pretrained word vectors, "
|
|
||||||
" cf. https://spacy.io/usage/models#languages."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
|
|
||||||
entity_frequencies = io.read_entity_to_count(entity_freq_path)
|
|
||||||
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
|
|
||||||
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
|
|
||||||
title_to_id,
|
|
||||||
id_to_descr,
|
|
||||||
entity_frequencies,
|
|
||||||
min_entity_freq
|
|
||||||
)
|
|
||||||
logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
|
|
||||||
|
|
||||||
logger.info("Training entity encoder")
|
|
||||||
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
|
||||||
encoder.train(description_list=description_list, to_print=True)
|
|
||||||
|
|
||||||
logger.info("Getting entity embeddings")
|
|
||||||
embeddings = encoder.apply_encoder(description_list)
|
|
||||||
|
|
||||||
logger.info("Adding {} entities".format(len(entity_list)))
|
|
||||||
kb.set_entities(
|
|
||||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
|
||||||
)
|
|
||||||
return entity_list, filtered_title_to_id
|
|
||||||
|
|
||||||
|
|
||||||
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
|
||||||
logger.info("Adding aliases from Wikipedia and Wikidata")
|
|
||||||
_add_aliases(
|
|
||||||
kb,
|
|
||||||
entity_list=entity_list,
|
|
||||||
title_to_id=filtered_title_to_id,
|
|
||||||
max_entities_per_alias=max_entities_per_alias,
|
|
||||||
min_occ=min_occ,
|
|
||||||
prior_prob_path=prior_prob_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
|
||||||
min_entity_freq: int = 10):
|
|
||||||
filtered_title_to_id = dict()
|
|
||||||
entity_list = []
|
|
||||||
description_list = []
|
|
||||||
frequency_list = []
|
|
||||||
for title, entity in title_to_id.items():
|
|
||||||
freq = entity_frequencies.get(title, 0)
|
|
||||||
desc = id_to_descr.get(entity, None)
|
|
||||||
if desc and freq > min_entity_freq:
|
|
||||||
entity_list.append(entity)
|
|
||||||
description_list.append(desc)
|
|
||||||
frequency_list.append(freq)
|
|
||||||
filtered_title_to_id[title] = entity
|
|
||||||
return filtered_title_to_id, entity_list, description_list, frequency_list
|
|
||||||
|
|
||||||
|
|
||||||
def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
|
||||||
wp_titles = title_to_id.keys()
|
|
||||||
|
|
||||||
# adding aliases with prior probabilities
|
|
||||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
|
||||||
logger.info("Adding WP aliases")
|
|
||||||
with prior_prob_path.open("r", encoding="utf8") as prior_file:
|
|
||||||
# skip header
|
|
||||||
prior_file.readline()
|
|
||||||
line = prior_file.readline()
|
|
||||||
previous_alias = None
|
|
||||||
total_count = 0
|
|
||||||
counts = []
|
|
||||||
entities = []
|
|
||||||
while line:
|
|
||||||
splits = line.replace("\n", "").split(sep="|")
|
|
||||||
new_alias = splits[0]
|
|
||||||
count = int(splits[1])
|
|
||||||
entity = splits[2]
|
|
||||||
|
|
||||||
if new_alias != previous_alias and previous_alias:
|
|
||||||
# done reading the previous alias --> output
|
|
||||||
if len(entities) > 0:
|
|
||||||
selected_entities = []
|
|
||||||
prior_probs = []
|
|
||||||
for ent_count, ent_string in zip(counts, entities):
|
|
||||||
if ent_string in wp_titles:
|
|
||||||
wd_id = title_to_id[ent_string]
|
|
||||||
p_entity_givenalias = ent_count / total_count
|
|
||||||
selected_entities.append(wd_id)
|
|
||||||
prior_probs.append(p_entity_givenalias)
|
|
||||||
|
|
||||||
if selected_entities:
|
|
||||||
try:
|
|
||||||
kb.add_alias(
|
|
||||||
alias=previous_alias,
|
|
||||||
entities=selected_entities,
|
|
||||||
probabilities=prior_probs,
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(e)
|
|
||||||
total_count = 0
|
|
||||||
counts = []
|
|
||||||
entities = []
|
|
||||||
|
|
||||||
total_count += count
|
|
||||||
|
|
||||||
if len(entities) < max_entities_per_alias and count >= min_occ:
|
|
||||||
counts.append(count)
|
|
||||||
entities.append(entity)
|
|
||||||
previous_alias = new_alias
|
|
||||||
|
|
||||||
line = prior_file.readline()
|
|
||||||
|
|
||||||
|
|
||||||
def read_kb(nlp, kb_file):
|
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
|
||||||
kb.load_bulk(kb_file)
|
|
||||||
return kb
|
|
|
@ -1,152 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from random import shuffle
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from spacy._ml import zero_init, create_default_optimizer
|
|
||||||
from spacy.cli.pretrain import get_cossim_loss
|
|
||||||
|
|
||||||
from thinc.v2v import Model
|
|
||||||
from thinc.api import chain
|
|
||||||
from thinc.neural._classes.affine import Affine
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class EntityEncoder:
|
|
||||||
"""
|
|
||||||
Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D).
|
|
||||||
This entity vector will be stored in the KB, for further downstream use in the entity model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
DROP = 0
|
|
||||||
BATCH_SIZE = 1000
|
|
||||||
|
|
||||||
# Set min. acceptable loss to avoid a 'mean of empty slice' warning by numpy
|
|
||||||
MIN_LOSS = 0.01
|
|
||||||
|
|
||||||
# Reasonable default to stop training when things are not improving
|
|
||||||
MAX_NO_IMPROVEMENT = 20
|
|
||||||
|
|
||||||
def __init__(self, nlp, input_dim, desc_width, epochs=5):
|
|
||||||
self.nlp = nlp
|
|
||||||
self.input_dim = input_dim
|
|
||||||
self.desc_width = desc_width
|
|
||||||
self.epochs = epochs
|
|
||||||
|
|
||||||
def apply_encoder(self, description_list):
|
|
||||||
if self.encoder is None:
|
|
||||||
raise ValueError("Can not apply encoder before training it")
|
|
||||||
|
|
||||||
batch_size = 100000
|
|
||||||
|
|
||||||
start = 0
|
|
||||||
stop = min(batch_size, len(description_list))
|
|
||||||
encodings = []
|
|
||||||
|
|
||||||
while start < len(description_list):
|
|
||||||
docs = list(self.nlp.pipe(description_list[start:stop]))
|
|
||||||
doc_embeddings = [self._get_doc_embedding(doc) for doc in docs]
|
|
||||||
enc = self.encoder(np.asarray(doc_embeddings))
|
|
||||||
encodings.extend(enc.tolist())
|
|
||||||
|
|
||||||
start = start + batch_size
|
|
||||||
stop = min(stop + batch_size, len(description_list))
|
|
||||||
logger.info("Encoded: {} entities".format(stop))
|
|
||||||
|
|
||||||
return encodings
|
|
||||||
|
|
||||||
def train(self, description_list, to_print=False):
|
|
||||||
processed, loss = self._train_model(description_list)
|
|
||||||
if to_print:
|
|
||||||
logger.info(
|
|
||||||
"Trained entity descriptions on {} ".format(processed) +
|
|
||||||
"(non-unique) descriptions across {} ".format(self.epochs) +
|
|
||||||
"epochs"
|
|
||||||
)
|
|
||||||
logger.info("Final loss: {}".format(loss))
|
|
||||||
|
|
||||||
def _train_model(self, description_list):
|
|
||||||
best_loss = 1.0
|
|
||||||
iter_since_best = 0
|
|
||||||
self._build_network(self.input_dim, self.desc_width)
|
|
||||||
|
|
||||||
processed = 0
|
|
||||||
loss = 1
|
|
||||||
# copy this list so that shuffling does not affect other functions
|
|
||||||
descriptions = description_list.copy()
|
|
||||||
to_continue = True
|
|
||||||
|
|
||||||
for i in range(self.epochs):
|
|
||||||
shuffle(descriptions)
|
|
||||||
|
|
||||||
batch_nr = 0
|
|
||||||
start = 0
|
|
||||||
stop = min(self.BATCH_SIZE, len(descriptions))
|
|
||||||
|
|
||||||
while to_continue and start < len(descriptions):
|
|
||||||
batch = []
|
|
||||||
for descr in descriptions[start:stop]:
|
|
||||||
doc = self.nlp(descr)
|
|
||||||
doc_vector = self._get_doc_embedding(doc)
|
|
||||||
batch.append(doc_vector)
|
|
||||||
|
|
||||||
loss = self._update(batch)
|
|
||||||
if batch_nr % 25 == 0:
|
|
||||||
logger.info("loss: {} ".format(loss))
|
|
||||||
processed += len(batch)
|
|
||||||
|
|
||||||
# in general, continue training if we haven't reached our ideal min yet
|
|
||||||
to_continue = loss > self.MIN_LOSS
|
|
||||||
|
|
||||||
# store the best loss and track how long it's been
|
|
||||||
if loss < best_loss:
|
|
||||||
best_loss = loss
|
|
||||||
iter_since_best = 0
|
|
||||||
else:
|
|
||||||
iter_since_best += 1
|
|
||||||
|
|
||||||
# stop learning if we haven't seen improvement since the last few iterations
|
|
||||||
if iter_since_best > self.MAX_NO_IMPROVEMENT:
|
|
||||||
to_continue = False
|
|
||||||
|
|
||||||
batch_nr += 1
|
|
||||||
start = start + self.BATCH_SIZE
|
|
||||||
stop = min(stop + self.BATCH_SIZE, len(descriptions))
|
|
||||||
|
|
||||||
return processed, loss
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_doc_embedding(doc):
|
|
||||||
indices = np.zeros((len(doc),), dtype="i")
|
|
||||||
for i, word in enumerate(doc):
|
|
||||||
if word.orth in doc.vocab.vectors.key2row:
|
|
||||||
indices[i] = doc.vocab.vectors.key2row[word.orth]
|
|
||||||
else:
|
|
||||||
indices[i] = 0
|
|
||||||
word_vectors = doc.vocab.vectors.data[indices]
|
|
||||||
doc_vector = np.mean(word_vectors, axis=0)
|
|
||||||
return doc_vector
|
|
||||||
|
|
||||||
def _build_network(self, orig_width, hidden_with):
|
|
||||||
with Model.define_operators({">>": chain}):
|
|
||||||
# very simple encoder-decoder model
|
|
||||||
self.encoder = Affine(hidden_with, orig_width)
|
|
||||||
self.model = self.encoder >> zero_init(
|
|
||||||
Affine(orig_width, hidden_with, drop_factor=0.0)
|
|
||||||
)
|
|
||||||
self.sgd = create_default_optimizer(self.model.ops)
|
|
||||||
|
|
||||||
def _update(self, vectors):
|
|
||||||
predictions, bp_model = self.model.begin_update(
|
|
||||||
np.asarray(vectors), drop=self.DROP
|
|
||||||
)
|
|
||||||
loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
|
|
||||||
bp_model(d_scores, sgd=self.sgd)
|
|
||||||
return loss / len(vectors)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_loss(golds, scores):
|
|
||||||
loss, gradients = get_cossim_loss(scores, golds)
|
|
||||||
return loss, gradients
|
|
|
@ -1,127 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import csv
|
|
||||||
|
|
||||||
# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
|
|
||||||
csv.field_size_limit(min(sys.maxsize, 2147483646))
|
|
||||||
|
|
||||||
""" This class provides reading/writing methods for temp files """
|
|
||||||
|
|
||||||
|
|
||||||
# Entity definition: WP title -> WD ID #
|
|
||||||
def write_title_to_id(entity_def_output, title_to_id):
|
|
||||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
|
||||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
|
||||||
for title, qid in title_to_id.items():
|
|
||||||
id_file.write(title + "|" + str(qid) + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def read_title_to_id(entity_def_output):
|
|
||||||
title_to_id = dict()
|
|
||||||
with entity_def_output.open("r", encoding="utf8") as id_file:
|
|
||||||
csvreader = csv.reader(id_file, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
title_to_id[row[0]] = row[1]
|
|
||||||
return title_to_id
|
|
||||||
|
|
||||||
|
|
||||||
# Entity aliases from WD: WD ID -> WD alias #
|
|
||||||
def write_id_to_alias(entity_alias_path, id_to_alias):
|
|
||||||
with entity_alias_path.open("w", encoding="utf8") as alias_file:
|
|
||||||
alias_file.write("WD_id" + "|" + "alias" + "\n")
|
|
||||||
for qid, alias_list in id_to_alias.items():
|
|
||||||
for alias in alias_list:
|
|
||||||
alias_file.write(str(qid) + "|" + alias + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def read_id_to_alias(entity_alias_path):
|
|
||||||
id_to_alias = dict()
|
|
||||||
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
|
||||||
csvreader = csv.reader(alias_file, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
qid = row[0]
|
|
||||||
alias = row[1]
|
|
||||||
alias_list = id_to_alias.get(qid, [])
|
|
||||||
alias_list.append(alias)
|
|
||||||
id_to_alias[qid] = alias_list
|
|
||||||
return id_to_alias
|
|
||||||
|
|
||||||
|
|
||||||
def read_alias_to_id_generator(entity_alias_path):
|
|
||||||
""" Read (aliases, qid) tuples """
|
|
||||||
|
|
||||||
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
|
||||||
csvreader = csv.reader(alias_file, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
qid = row[0]
|
|
||||||
alias = row[1]
|
|
||||||
yield alias, qid
|
|
||||||
|
|
||||||
|
|
||||||
# Entity descriptions from WD: WD ID -> WD alias #
|
|
||||||
def write_id_to_descr(entity_descr_output, id_to_descr):
|
|
||||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
|
||||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
|
||||||
for qid, descr in id_to_descr.items():
|
|
||||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def read_id_to_descr(entity_desc_path):
|
|
||||||
id_to_desc = dict()
|
|
||||||
with entity_desc_path.open("r", encoding="utf8") as descr_file:
|
|
||||||
csvreader = csv.reader(descr_file, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
id_to_desc[row[0]] = row[1]
|
|
||||||
return id_to_desc
|
|
||||||
|
|
||||||
|
|
||||||
# Entity counts from WP: WP title -> count #
|
|
||||||
def write_entity_to_count(prior_prob_input, count_output):
|
|
||||||
# Write entity counts for quick access later
|
|
||||||
entity_to_count = dict()
|
|
||||||
total_count = 0
|
|
||||||
|
|
||||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
|
||||||
# skip header
|
|
||||||
prior_file.readline()
|
|
||||||
line = prior_file.readline()
|
|
||||||
|
|
||||||
while line:
|
|
||||||
splits = line.replace("\n", "").split(sep="|")
|
|
||||||
# alias = splits[0]
|
|
||||||
count = int(splits[1])
|
|
||||||
entity = splits[2]
|
|
||||||
|
|
||||||
current_count = entity_to_count.get(entity, 0)
|
|
||||||
entity_to_count[entity] = current_count + count
|
|
||||||
|
|
||||||
total_count += count
|
|
||||||
|
|
||||||
line = prior_file.readline()
|
|
||||||
|
|
||||||
with count_output.open("w", encoding="utf8") as entity_file:
|
|
||||||
entity_file.write("entity" + "|" + "count" + "\n")
|
|
||||||
for entity, count in entity_to_count.items():
|
|
||||||
entity_file.write(entity + "|" + str(count) + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def read_entity_to_count(count_input):
|
|
||||||
entity_to_count = dict()
|
|
||||||
with count_input.open("r", encoding="utf8") as csvfile:
|
|
||||||
csvreader = csv.reader(csvfile, delimiter="|")
|
|
||||||
# skip header
|
|
||||||
next(csvreader)
|
|
||||||
for row in csvreader:
|
|
||||||
entity_to_count[row[0]] = int(row[1])
|
|
||||||
|
|
||||||
return entity_to_count
|
|
|
@ -1,128 +0,0 @@
|
||||||
# coding: utf8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
# List of meta pages in Wikidata, should be kept out of the Knowledge base
|
|
||||||
WD_META_ITEMS = [
|
|
||||||
"Q163875",
|
|
||||||
"Q191780",
|
|
||||||
"Q224414",
|
|
||||||
"Q4167836",
|
|
||||||
"Q4167410",
|
|
||||||
"Q4663903",
|
|
||||||
"Q11266439",
|
|
||||||
"Q13406463",
|
|
||||||
"Q15407973",
|
|
||||||
"Q18616576",
|
|
||||||
"Q19887878",
|
|
||||||
"Q22808320",
|
|
||||||
"Q23894233",
|
|
||||||
"Q33120876",
|
|
||||||
"Q42104522",
|
|
||||||
"Q47460393",
|
|
||||||
"Q64875536",
|
|
||||||
"Q66480449",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: add more cases from non-English WP's
|
|
||||||
|
|
||||||
# List of prefixes that refer to Wikipedia "file" pages
|
|
||||||
WP_FILE_NAMESPACE = ["Bestand", "File"]
|
|
||||||
|
|
||||||
# List of prefixes that refer to Wikipedia "category" pages
|
|
||||||
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
|
|
||||||
|
|
||||||
# List of prefixes that refer to Wikipedia "meta" pages
|
|
||||||
# these will/should be matched ignoring case
|
|
||||||
WP_META_NAMESPACE = (
|
|
||||||
WP_FILE_NAMESPACE
|
|
||||||
+ WP_CATEGORY_NAMESPACE
|
|
||||||
+ [
|
|
||||||
"b",
|
|
||||||
"betawikiversity",
|
|
||||||
"Book",
|
|
||||||
"c",
|
|
||||||
"Commons",
|
|
||||||
"d",
|
|
||||||
"dbdump",
|
|
||||||
"download",
|
|
||||||
"Draft",
|
|
||||||
"Education",
|
|
||||||
"Foundation",
|
|
||||||
"Gadget",
|
|
||||||
"Gadget definition",
|
|
||||||
"Gebruiker",
|
|
||||||
"gerrit",
|
|
||||||
"Help",
|
|
||||||
"Image",
|
|
||||||
"Incubator",
|
|
||||||
"m",
|
|
||||||
"mail",
|
|
||||||
"mailarchive",
|
|
||||||
"media",
|
|
||||||
"MediaWiki",
|
|
||||||
"MediaWiki talk",
|
|
||||||
"Mediawikiwiki",
|
|
||||||
"MediaZilla",
|
|
||||||
"Meta",
|
|
||||||
"Metawikipedia",
|
|
||||||
"Module",
|
|
||||||
"mw",
|
|
||||||
"n",
|
|
||||||
"nost",
|
|
||||||
"oldwikisource",
|
|
||||||
"otrs",
|
|
||||||
"OTRSwiki",
|
|
||||||
"Overleg gebruiker",
|
|
||||||
"outreach",
|
|
||||||
"outreachwiki",
|
|
||||||
"Portal",
|
|
||||||
"phab",
|
|
||||||
"Phabricator",
|
|
||||||
"Project",
|
|
||||||
"q",
|
|
||||||
"quality",
|
|
||||||
"rev",
|
|
||||||
"s",
|
|
||||||
"spcom",
|
|
||||||
"Special",
|
|
||||||
"species",
|
|
||||||
"Strategy",
|
|
||||||
"sulutil",
|
|
||||||
"svn",
|
|
||||||
"Talk",
|
|
||||||
"Template",
|
|
||||||
"Template talk",
|
|
||||||
"Testwiki",
|
|
||||||
"ticket",
|
|
||||||
"TimedText",
|
|
||||||
"Toollabs",
|
|
||||||
"tools",
|
|
||||||
"tswiki",
|
|
||||||
"User",
|
|
||||||
"User talk",
|
|
||||||
"v",
|
|
||||||
"voy",
|
|
||||||
"w",
|
|
||||||
"Wikibooks",
|
|
||||||
"Wikidata",
|
|
||||||
"wikiHow",
|
|
||||||
"Wikinvest",
|
|
||||||
"wikilivres",
|
|
||||||
"Wikimedia",
|
|
||||||
"Wikinews",
|
|
||||||
"Wikipedia",
|
|
||||||
"Wikipedia talk",
|
|
||||||
"Wikiquote",
|
|
||||||
"Wikisource",
|
|
||||||
"Wikispecies",
|
|
||||||
"Wikitech",
|
|
||||||
"Wikiversity",
|
|
||||||
"Wikivoyage",
|
|
||||||
"wikt",
|
|
||||||
"wiktionary",
|
|
||||||
"wmf",
|
|
||||||
"wmania",
|
|
||||||
"WP",
|
|
||||||
]
|
|
||||||
)
|
|
|
@ -1,179 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
"""Script to process Wikipedia and Wikidata dumps and create a knowledge base (KB)
|
|
||||||
with specific parameters. Intermediate files are written to disk.
|
|
||||||
|
|
||||||
Running the full pipeline on a standard laptop, may take up to 13 hours of processing.
|
|
||||||
Use the -p, -d and -s options to speed up processing using the intermediate files
|
|
||||||
from a previous run.
|
|
||||||
|
|
||||||
For the Wikidata dump: get the latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
|
||||||
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
|
||||||
from https://dumps.wikimedia.org/enwiki/latest/
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
import plac
|
|
||||||
|
|
||||||
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
|
||||||
from bin.wiki_entity_linking import wiki_io as io
|
|
||||||
from bin.wiki_entity_linking import kb_creator
|
|
||||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
|
||||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
|
|
||||||
import spacy
|
|
||||||
from bin.wiki_entity_linking.kb_creator import read_kb
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
|
||||||
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
|
|
||||||
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
|
|
||||||
output_dir=("Output directory", "positional", None, Path),
|
|
||||||
model=("Model name or path, should include pretrained vectors.", "positional", None, str),
|
|
||||||
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
|
|
||||||
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
|
|
||||||
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
|
|
||||||
entity_vector_length=("Length of entity vectors (default 64)", "option", "v", int),
|
|
||||||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
|
||||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
|
||||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
|
||||||
descr_from_wp=("Flag for using descriptions from WP instead of WD (default False)", "flag", "wp"),
|
|
||||||
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
|
||||||
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
|
||||||
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
|
||||||
lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
|
|
||||||
)
|
|
||||||
def main(
|
|
||||||
wd_json,
|
|
||||||
wp_xml,
|
|
||||||
output_dir,
|
|
||||||
model,
|
|
||||||
max_per_alias=10,
|
|
||||||
min_freq=20,
|
|
||||||
min_pair=5,
|
|
||||||
entity_vector_length=64,
|
|
||||||
loc_prior_prob=None,
|
|
||||||
loc_entity_defs=None,
|
|
||||||
loc_entity_alias=None,
|
|
||||||
loc_entity_desc=None,
|
|
||||||
descr_from_wp=False,
|
|
||||||
limit_prior=None,
|
|
||||||
limit_train=None,
|
|
||||||
limit_wd=None,
|
|
||||||
lang="en",
|
|
||||||
):
|
|
||||||
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
|
||||||
entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
|
|
||||||
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
|
||||||
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
|
||||||
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
|
||||||
training_entities_path = output_dir / TRAINING_DATA_FILE
|
|
||||||
kb_path = output_dir / KB_FILE
|
|
||||||
|
|
||||||
logger.info("Creating KB with Wikipedia and WikiData")
|
|
||||||
|
|
||||||
# STEP 0: set up IO
|
|
||||||
if not output_dir.exists():
|
|
||||||
output_dir.mkdir(parents=True)
|
|
||||||
|
|
||||||
# STEP 1: Load the NLP object
|
|
||||||
logger.info("STEP 1: Loading NLP model {}".format(model))
|
|
||||||
nlp = spacy.load(model)
|
|
||||||
|
|
||||||
# check the length of the nlp vectors
|
|
||||||
if "vectors" not in nlp.meta or not nlp.vocab.vectors.size:
|
|
||||||
raise ValueError(
|
|
||||||
"The `nlp` object should have access to pretrained word vectors, "
|
|
||||||
" cf. https://spacy.io/usage/models#languages."
|
|
||||||
)
|
|
||||||
|
|
||||||
# STEP 2: create prior probabilities from WP
|
|
||||||
if not prior_prob_path.exists():
|
|
||||||
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
|
||||||
logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
|
|
||||||
if limit_prior is not None:
|
|
||||||
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
|
|
||||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
|
|
||||||
else:
|
|
||||||
logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
|
|
||||||
|
|
||||||
# STEP 3: calculate entity frequencies
|
|
||||||
if not entity_freq_path.exists():
|
|
||||||
logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
|
|
||||||
io.write_entity_to_count(prior_prob_path, entity_freq_path)
|
|
||||||
else:
|
|
||||||
logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
|
|
||||||
|
|
||||||
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
|
||||||
if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
|
|
||||||
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
|
||||||
logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
|
|
||||||
if limit_wd is not None:
|
|
||||||
logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
|
|
||||||
title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
|
|
||||||
wd_json,
|
|
||||||
limit_wd,
|
|
||||||
to_print=False,
|
|
||||||
lang=lang,
|
|
||||||
parse_descr=(not descr_from_wp),
|
|
||||||
)
|
|
||||||
io.write_title_to_id(entity_defs_path, title_to_id)
|
|
||||||
|
|
||||||
logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
|
|
||||||
io.write_id_to_alias(entity_alias_path, id_to_alias)
|
|
||||||
|
|
||||||
if not descr_from_wp:
|
|
||||||
logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
|
|
||||||
io.write_id_to_descr(entity_descr_path, id_to_descr)
|
|
||||||
else:
|
|
||||||
logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
|
|
||||||
logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
|
|
||||||
if not descr_from_wp:
|
|
||||||
logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
|
|
||||||
|
|
||||||
# STEP 5: Getting gold entities from Wikipedia
|
|
||||||
if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
|
|
||||||
logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
|
|
||||||
if limit_train is not None:
|
|
||||||
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
|
|
||||||
wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
|
|
||||||
training_entities_path, descr_from_wp, limit_train)
|
|
||||||
if descr_from_wp:
|
|
||||||
logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
|
|
||||||
else:
|
|
||||||
logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
|
|
||||||
if descr_from_wp:
|
|
||||||
logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
|
|
||||||
|
|
||||||
# STEP 6: creating the actual KB
|
|
||||||
# It takes ca. 30 minutes to pretrain the entity embeddings
|
|
||||||
if not kb_path.exists():
|
|
||||||
logger.info("STEP 6: Creating the KB at {}".format(kb_path))
|
|
||||||
kb = kb_creator.create_kb(
|
|
||||||
nlp=nlp,
|
|
||||||
max_entities_per_alias=max_per_alias,
|
|
||||||
min_entity_freq=min_freq,
|
|
||||||
min_occ=min_pair,
|
|
||||||
entity_def_path=entity_defs_path,
|
|
||||||
entity_descr_path=entity_descr_path,
|
|
||||||
entity_alias_path=entity_alias_path,
|
|
||||||
entity_freq_path=entity_freq_path,
|
|
||||||
prior_prob_path=prior_prob_path,
|
|
||||||
entity_vector_length=entity_vector_length,
|
|
||||||
)
|
|
||||||
kb.dump(kb_path)
|
|
||||||
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
|
||||||
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
|
||||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
|
||||||
else:
|
|
||||||
logger.info("STEP 6: KB already exists at {}".format(kb_path))
|
|
||||||
|
|
||||||
logger.info("Done!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
|
||||||
plac.call(main)
|
|
|
@ -1,154 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import bz2
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
|
|
||||||
# Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
|
|
||||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
|
||||||
|
|
||||||
site_filter = '{}wiki'.format(lang)
|
|
||||||
|
|
||||||
# filter: currently defined as OR: one hit suffices to be removed from further processing
|
|
||||||
exclude_list = WD_META_ITEMS
|
|
||||||
|
|
||||||
# punctuation
|
|
||||||
exclude_list.extend(["Q1383557", "Q10617810"])
|
|
||||||
|
|
||||||
# letters etc
|
|
||||||
exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
|
|
||||||
|
|
||||||
neg_prop_filter = {
|
|
||||||
'P31': exclude_list, # instance of
|
|
||||||
'P279': exclude_list # subclass
|
|
||||||
}
|
|
||||||
|
|
||||||
title_to_id = dict()
|
|
||||||
id_to_descr = dict()
|
|
||||||
id_to_alias = dict()
|
|
||||||
|
|
||||||
# parse appropriate fields - depending on what we need in the KB
|
|
||||||
parse_properties = False
|
|
||||||
parse_sitelinks = True
|
|
||||||
parse_labels = False
|
|
||||||
parse_aliases = True
|
|
||||||
parse_claims = True
|
|
||||||
|
|
||||||
with bz2.open(wikidata_file, mode='rb') as file:
|
|
||||||
for cnt, line in enumerate(file):
|
|
||||||
if limit and cnt >= limit:
|
|
||||||
break
|
|
||||||
if cnt % 500000 == 0 and cnt > 0:
|
|
||||||
logger.info("processed {} lines of WikiData JSON dump".format(cnt))
|
|
||||||
clean_line = line.strip()
|
|
||||||
if clean_line.endswith(b","):
|
|
||||||
clean_line = clean_line[:-1]
|
|
||||||
if len(clean_line) > 1:
|
|
||||||
obj = json.loads(clean_line)
|
|
||||||
entry_type = obj["type"]
|
|
||||||
|
|
||||||
if entry_type == "item":
|
|
||||||
keep = True
|
|
||||||
|
|
||||||
claims = obj["claims"]
|
|
||||||
if parse_claims:
|
|
||||||
for prop, value_set in neg_prop_filter.items():
|
|
||||||
claim_property = claims.get(prop, None)
|
|
||||||
if claim_property:
|
|
||||||
for cp in claim_property:
|
|
||||||
cp_id = (
|
|
||||||
cp["mainsnak"]
|
|
||||||
.get("datavalue", {})
|
|
||||||
.get("value", {})
|
|
||||||
.get("id")
|
|
||||||
)
|
|
||||||
cp_rank = cp["rank"]
|
|
||||||
if cp_rank != "deprecated" and cp_id in value_set:
|
|
||||||
keep = False
|
|
||||||
|
|
||||||
if keep:
|
|
||||||
unique_id = obj["id"]
|
|
||||||
|
|
||||||
if to_print:
|
|
||||||
print("ID:", unique_id)
|
|
||||||
print("type:", entry_type)
|
|
||||||
|
|
||||||
# parsing all properties that refer to other entities
|
|
||||||
if parse_properties:
|
|
||||||
for prop, claim_property in claims.items():
|
|
||||||
cp_dicts = [
|
|
||||||
cp["mainsnak"]["datavalue"].get("value")
|
|
||||||
for cp in claim_property
|
|
||||||
if cp["mainsnak"].get("datavalue")
|
|
||||||
]
|
|
||||||
cp_values = [
|
|
||||||
cp_dict.get("id")
|
|
||||||
for cp_dict in cp_dicts
|
|
||||||
if isinstance(cp_dict, dict)
|
|
||||||
if cp_dict.get("id") is not None
|
|
||||||
]
|
|
||||||
if cp_values:
|
|
||||||
if to_print:
|
|
||||||
print("prop:", prop, cp_values)
|
|
||||||
|
|
||||||
found_link = False
|
|
||||||
if parse_sitelinks:
|
|
||||||
site_value = obj["sitelinks"].get(site_filter, None)
|
|
||||||
if site_value:
|
|
||||||
site = site_value["title"]
|
|
||||||
if to_print:
|
|
||||||
print(site_filter, ":", site)
|
|
||||||
title_to_id[site] = unique_id
|
|
||||||
found_link = True
|
|
||||||
|
|
||||||
if parse_labels:
|
|
||||||
labels = obj["labels"]
|
|
||||||
if labels:
|
|
||||||
lang_label = labels.get(lang, None)
|
|
||||||
if lang_label:
|
|
||||||
if to_print:
|
|
||||||
print(
|
|
||||||
"label (" + lang + "):", lang_label["value"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if found_link and parse_descr:
|
|
||||||
descriptions = obj["descriptions"]
|
|
||||||
if descriptions:
|
|
||||||
lang_descr = descriptions.get(lang, None)
|
|
||||||
if lang_descr:
|
|
||||||
if to_print:
|
|
||||||
print(
|
|
||||||
"description (" + lang + "):",
|
|
||||||
lang_descr["value"],
|
|
||||||
)
|
|
||||||
id_to_descr[unique_id] = lang_descr["value"]
|
|
||||||
|
|
||||||
if parse_aliases:
|
|
||||||
aliases = obj["aliases"]
|
|
||||||
if aliases:
|
|
||||||
lang_aliases = aliases.get(lang, None)
|
|
||||||
if lang_aliases:
|
|
||||||
for item in lang_aliases:
|
|
||||||
if to_print:
|
|
||||||
print(
|
|
||||||
"alias (" + lang + "):", item["value"]
|
|
||||||
)
|
|
||||||
alias_list = id_to_alias.get(unique_id, [])
|
|
||||||
alias_list.append(item["value"])
|
|
||||||
id_to_alias[unique_id] = alias_list
|
|
||||||
|
|
||||||
if to_print:
|
|
||||||
print()
|
|
||||||
|
|
||||||
# log final number of lines processed
|
|
||||||
logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
|
|
||||||
return title_to_id, id_to_descr, id_to_alias
|
|
||||||
|
|
||||||
|
|
|
@ -1,172 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
"""Script that takes a previously created Knowledge Base and trains an entity linking
|
|
||||||
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
|
||||||
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
|
||||||
as created by the script `wikidata_create_kb`.
|
|
||||||
|
|
||||||
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
|
||||||
from https://dumps.wikimedia.org/enwiki/latest/
|
|
||||||
"""
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import random
|
|
||||||
import logging
|
|
||||||
import spacy
|
|
||||||
from pathlib import Path
|
|
||||||
import plac
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from bin.wiki_entity_linking import wikipedia_processor
|
|
||||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
|
||||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
|
|
||||||
from bin.wiki_entity_linking.kb_creator import read_kb
|
|
||||||
|
|
||||||
from spacy.util import minibatch, compounding
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
|
||||||
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
|
|
||||||
output_dir=("Output directory", "option", "o", Path),
|
|
||||||
loc_training=("Location to training data", "option", "k", Path),
|
|
||||||
epochs=("Number of training iterations (default 10)", "option", "e", int),
|
|
||||||
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
|
||||||
lr=("Learning rate (default 0.005)", "option", "n", float),
|
|
||||||
l2=("L2 regularization", "option", "r", float),
|
|
||||||
train_articles=("# training articles (default 90% of all)", "option", "t", int),
|
|
||||||
dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
|
|
||||||
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
|
||||||
)
|
|
||||||
def main(
|
|
||||||
dir_kb,
|
|
||||||
output_dir=None,
|
|
||||||
loc_training=None,
|
|
||||||
epochs=10,
|
|
||||||
dropout=0.5,
|
|
||||||
lr=0.005,
|
|
||||||
l2=1e-6,
|
|
||||||
train_articles=None,
|
|
||||||
dev_articles=None,
|
|
||||||
labels_discard=None
|
|
||||||
):
|
|
||||||
if not output_dir:
|
|
||||||
logger.warning("No output dir specified so no results will be written, are you sure about this ?")
|
|
||||||
|
|
||||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
|
||||||
|
|
||||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
|
||||||
training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
|
|
||||||
nlp_dir = dir_kb / KB_MODEL_DIR
|
|
||||||
kb_path = dir_kb / KB_FILE
|
|
||||||
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
|
||||||
|
|
||||||
# STEP 0: set up IO
|
|
||||||
if not output_dir.exists():
|
|
||||||
output_dir.mkdir()
|
|
||||||
|
|
||||||
# STEP 1 : load the NLP object
|
|
||||||
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
|
||||||
nlp = spacy.load(nlp_dir)
|
|
||||||
logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
|
||||||
|
|
||||||
# check that there is a NER component in the pipeline
|
|
||||||
if "ner" not in nlp.pipe_names:
|
|
||||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
|
||||||
|
|
||||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
|
||||||
kb = read_kb(nlp, kb_path)
|
|
||||||
|
|
||||||
# STEP 2: read the training dataset previously created from WP
|
|
||||||
logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
|
|
||||||
train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
|
|
||||||
logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
|
|
||||||
.format(len(train_indices), train_articles if train_articles else "all"))
|
|
||||||
logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation"
|
|
||||||
.format(len(dev_indices), dev_articles if dev_articles else "all"))
|
|
||||||
if dev_articles:
|
|
||||||
dev_indices = dev_indices[0:dev_articles]
|
|
||||||
|
|
||||||
# STEP 3: create and train an entity linking pipe
|
|
||||||
logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
|
|
||||||
if labels_discard:
|
|
||||||
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
|
||||||
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
|
||||||
else:
|
|
||||||
labels_discard = []
|
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(
|
|
||||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
|
||||||
"labels_discard": labels_discard}
|
|
||||||
)
|
|
||||||
el_pipe.set_kb(kb)
|
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
|
||||||
|
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
|
|
||||||
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
|
|
||||||
optimizer = nlp.begin_training()
|
|
||||||
optimizer.learn_rate = lr
|
|
||||||
optimizer.L2 = l2
|
|
||||||
|
|
||||||
logger.info("Dev Baseline Accuracies:")
|
|
||||||
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
|
||||||
dev=True, line_ids=dev_indices,
|
|
||||||
kb=kb, labels_discard=labels_discard)
|
|
||||||
|
|
||||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))
|
|
||||||
|
|
||||||
for itn in range(epochs):
|
|
||||||
random.shuffle(train_indices)
|
|
||||||
losses = {}
|
|
||||||
batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
|
|
||||||
batchnr = 0
|
|
||||||
articles_processed = 0
|
|
||||||
|
|
||||||
# we either process the whole training file, or just a part each epoch
|
|
||||||
bar_total = len(train_indices)
|
|
||||||
if train_articles:
|
|
||||||
bar_total = train_articles
|
|
||||||
|
|
||||||
with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
|
|
||||||
for batch in batches:
|
|
||||||
if not train_articles or articles_processed < train_articles:
|
|
||||||
with nlp.disable_pipes("entity_linker"):
|
|
||||||
train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
|
||||||
dev=False, line_ids=batch,
|
|
||||||
kb=kb, labels_discard=labels_discard)
|
|
||||||
docs, golds = zip(*train_batch)
|
|
||||||
try:
|
|
||||||
with nlp.disable_pipes(*other_pipes):
|
|
||||||
nlp.update(
|
|
||||||
docs=docs,
|
|
||||||
golds=golds,
|
|
||||||
sgd=optimizer,
|
|
||||||
drop=dropout,
|
|
||||||
losses=losses,
|
|
||||||
)
|
|
||||||
batchnr += 1
|
|
||||||
articles_processed += len(docs)
|
|
||||||
pbar.update(len(docs))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error updating batch:" + str(e))
|
|
||||||
if batchnr > 0:
|
|
||||||
logging.info("Epoch {} trained on {} articles, train loss {}"
|
|
||||||
.format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
|
|
||||||
# re-read the dev_data (data is returned as a generator)
|
|
||||||
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
|
||||||
dev=True, line_ids=dev_indices,
|
|
||||||
kb=kb, labels_discard=labels_discard)
|
|
||||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))
|
|
||||||
|
|
||||||
if output_dir:
|
|
||||||
# STEP 4: write the NLP pipeline (now including an EL model) to file
|
|
||||||
logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
|
||||||
logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
|
|
||||||
nlp.to_disk(nlp_output_dir)
|
|
||||||
|
|
||||||
logger.info("Done!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
|
||||||
plac.call(main)
|
|
|
@ -1,565 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import re
|
|
||||||
import bz2
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
import json
|
|
||||||
|
|
||||||
from spacy.gold import GoldParse
|
|
||||||
from bin.wiki_entity_linking import wiki_io as io
|
|
||||||
from bin.wiki_entity_linking.wiki_namespaces import (
|
|
||||||
WP_META_NAMESPACE,
|
|
||||||
WP_FILE_NAMESPACE,
|
|
||||||
WP_CATEGORY_NAMESPACE,
|
|
||||||
)
|
|
||||||
|
|
||||||
"""
|
|
||||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
|
||||||
Write these results to file for downstream KB and training data generation.
|
|
||||||
|
|
||||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ENTITY_FILE = "gold_entities.csv"
|
|
||||||
|
|
||||||
map_alias_to_link = dict()
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
|
||||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
|
||||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
|
||||||
info_regex = re.compile(r"{[^{]*?}")
|
|
||||||
html_regex = re.compile(r"<!--[^-]*-->")
|
|
||||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
|
||||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
|
||||||
|
|
||||||
# find the links
|
|
||||||
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
|
||||||
|
|
||||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
|
||||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
|
||||||
# match on Namespace: optionally preceded by a :
|
|
||||||
for ns in WP_META_NAMESPACE:
|
|
||||||
ns_regex += "|" + ":?" + ns + ":"
|
|
||||||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
|
||||||
|
|
||||||
files = r""
|
|
||||||
for f in WP_FILE_NAMESPACE:
|
|
||||||
files += "\[\[" + f + ":[^[\]]+]]" + "|"
|
|
||||||
files = files[0 : len(files) - 1]
|
|
||||||
file_regex = re.compile(files)
|
|
||||||
|
|
||||||
cats = r""
|
|
||||||
for c in WP_CATEGORY_NAMESPACE:
|
|
||||||
cats += "\[\[" + c + ":[^\[]*]]" + "|"
|
|
||||||
cats = cats[0 : len(cats) - 1]
|
|
||||||
category_regex = re.compile(cats)
|
|
||||||
|
|
||||||
|
|
||||||
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
|
||||||
"""
|
|
||||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
|
||||||
The full file takes about 2-3h to parse 1100M lines.
|
|
||||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from,
|
|
||||||
though dev test articles are excluded in order not to get an artificially strong baseline.
|
|
||||||
"""
|
|
||||||
cnt = 0
|
|
||||||
read_id = False
|
|
||||||
current_article_id = None
|
|
||||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
|
||||||
line = file.readline()
|
|
||||||
while line and (not limit or cnt < limit):
|
|
||||||
if cnt % 25000000 == 0 and cnt > 0:
|
|
||||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
|
||||||
clean_line = line.strip().decode("utf-8")
|
|
||||||
|
|
||||||
# we attempt at reading the article's ID (but not the revision or contributor ID)
|
|
||||||
if "<revision>" in clean_line or "<contributor>" in clean_line:
|
|
||||||
read_id = False
|
|
||||||
if "<page>" in clean_line:
|
|
||||||
read_id = True
|
|
||||||
|
|
||||||
if read_id:
|
|
||||||
ids = id_regex.search(clean_line)
|
|
||||||
if ids:
|
|
||||||
current_article_id = ids[0]
|
|
||||||
|
|
||||||
# only processing prior probabilities from true training (non-dev) articles
|
|
||||||
if not is_dev(current_article_id):
|
|
||||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
|
||||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
|
||||||
_store_alias(
|
|
||||||
alias, entity, normalize_alias=norm, normalize_entity=True
|
|
||||||
)
|
|
||||||
|
|
||||||
line = file.readline()
|
|
||||||
cnt += 1
|
|
||||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
|
||||||
logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
|
|
||||||
|
|
||||||
# write all aliases and their entities and count occurrences to file
|
|
||||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
|
||||||
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
|
||||||
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
|
||||||
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
|
|
||||||
for entity, count in s_dict:
|
|
||||||
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def _store_alias(alias, entity, normalize_alias=False, normalize_entity=True):
|
|
||||||
alias = alias.strip()
|
|
||||||
entity = entity.strip()
|
|
||||||
|
|
||||||
# remove everything after # as this is not part of the title but refers to a specific paragraph
|
|
||||||
if normalize_entity:
|
|
||||||
# wikipedia titles are always capitalized
|
|
||||||
entity = _capitalize_first(entity.split("#")[0])
|
|
||||||
if normalize_alias:
|
|
||||||
alias = alias.split("#")[0]
|
|
||||||
|
|
||||||
if alias and entity:
|
|
||||||
alias_dict = map_alias_to_link.get(alias, dict())
|
|
||||||
entity_count = alias_dict.get(entity, 0)
|
|
||||||
alias_dict[entity] = entity_count + 1
|
|
||||||
map_alias_to_link[alias] = alias_dict
|
|
||||||
|
|
||||||
|
|
||||||
def get_wp_links(text):
|
|
||||||
aliases = []
|
|
||||||
entities = []
|
|
||||||
normalizations = []
|
|
||||||
|
|
||||||
matches = link_regex.findall(text)
|
|
||||||
for match in matches:
|
|
||||||
match = match[2:][:-2].replace("_", " ").strip()
|
|
||||||
|
|
||||||
if ns_regex.match(match):
|
|
||||||
pass # ignore the entity if it points to a "meta" page
|
|
||||||
|
|
||||||
# this is a simple [[link]], with the alias the same as the mention
|
|
||||||
elif "|" not in match:
|
|
||||||
aliases.append(match)
|
|
||||||
entities.append(match)
|
|
||||||
normalizations.append(True)
|
|
||||||
|
|
||||||
# in wiki format, the link is written as [[entity|alias]]
|
|
||||||
else:
|
|
||||||
splits = match.split("|")
|
|
||||||
entity = splits[0].strip()
|
|
||||||
alias = splits[1].strip()
|
|
||||||
# specific wiki format [[alias (specification)|]]
|
|
||||||
if len(alias) == 0 and "(" in entity:
|
|
||||||
alias = entity.split("(")[0]
|
|
||||||
aliases.append(alias)
|
|
||||||
entities.append(entity)
|
|
||||||
normalizations.append(False)
|
|
||||||
else:
|
|
||||||
aliases.append(alias)
|
|
||||||
entities.append(entity)
|
|
||||||
normalizations.append(False)
|
|
||||||
|
|
||||||
return aliases, entities, normalizations
|
|
||||||
|
|
||||||
|
|
||||||
def _capitalize_first(text):
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
result = text[0].capitalize()
|
|
||||||
if len(result) > 0:
|
|
||||||
result += text[1:]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def create_training_and_desc(
|
|
||||||
wp_input, def_input, desc_output, training_output, parse_desc, limit=None
|
|
||||||
):
|
|
||||||
wp_to_id = io.read_title_to_id(def_input)
|
|
||||||
_process_wikipedia_texts(
|
|
||||||
wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _process_wikipedia_texts(
|
|
||||||
wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Read the XML wikipedia data to parse out training data:
|
|
||||||
raw text data + positive instances
|
|
||||||
"""
|
|
||||||
|
|
||||||
read_ids = set()
|
|
||||||
|
|
||||||
with output.open("a", encoding="utf8") as descr_file, training_output.open(
|
|
||||||
"w", encoding="utf8"
|
|
||||||
) as entity_file:
|
|
||||||
if parse_descriptions:
|
|
||||||
_write_training_description(descr_file, "WD_id", "description")
|
|
||||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
|
||||||
article_count = 0
|
|
||||||
article_text = ""
|
|
||||||
article_title = None
|
|
||||||
article_id = None
|
|
||||||
reading_text = False
|
|
||||||
reading_revision = False
|
|
||||||
|
|
||||||
for line in file:
|
|
||||||
clean_line = line.strip().decode("utf-8")
|
|
||||||
|
|
||||||
if clean_line == "<revision>":
|
|
||||||
reading_revision = True
|
|
||||||
elif clean_line == "</revision>":
|
|
||||||
reading_revision = False
|
|
||||||
|
|
||||||
# Start reading new page
|
|
||||||
if clean_line == "<page>":
|
|
||||||
article_text = ""
|
|
||||||
article_title = None
|
|
||||||
article_id = None
|
|
||||||
# finished reading this page
|
|
||||||
elif clean_line == "</page>":
|
|
||||||
if article_id:
|
|
||||||
clean_text, entities = _process_wp_text(
|
|
||||||
article_title, article_text, wp_to_id
|
|
||||||
)
|
|
||||||
if clean_text is not None and entities is not None:
|
|
||||||
_write_training_entities(
|
|
||||||
entity_file, article_id, clean_text, entities
|
|
||||||
)
|
|
||||||
|
|
||||||
if article_title in wp_to_id and parse_descriptions:
|
|
||||||
description = " ".join(
|
|
||||||
clean_text[:1000].split(" ")[:-1]
|
|
||||||
)
|
|
||||||
_write_training_description(
|
|
||||||
descr_file, wp_to_id[article_title], description
|
|
||||||
)
|
|
||||||
article_count += 1
|
|
||||||
if article_count % 10000 == 0 and article_count > 0:
|
|
||||||
logger.info(
|
|
||||||
"Processed {} articles".format(article_count)
|
|
||||||
)
|
|
||||||
if limit and article_count >= limit:
|
|
||||||
break
|
|
||||||
article_text = ""
|
|
||||||
article_title = None
|
|
||||||
article_id = None
|
|
||||||
reading_text = False
|
|
||||||
reading_revision = False
|
|
||||||
|
|
||||||
# start reading text within a page
|
|
||||||
if "<text" in clean_line:
|
|
||||||
reading_text = True
|
|
||||||
|
|
||||||
if reading_text:
|
|
||||||
article_text += " " + clean_line
|
|
||||||
|
|
||||||
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
|
||||||
if "</text" in clean_line:
|
|
||||||
reading_text = False
|
|
||||||
|
|
||||||
# read the ID of this article (outside the revision portion of the document)
|
|
||||||
if not reading_revision:
|
|
||||||
ids = id_regex.search(clean_line)
|
|
||||||
if ids:
|
|
||||||
article_id = ids[0]
|
|
||||||
if article_id in read_ids:
|
|
||||||
logger.info(
|
|
||||||
"Found duplicate article ID", article_id, clean_line
|
|
||||||
) # This should never happen ...
|
|
||||||
read_ids.add(article_id)
|
|
||||||
|
|
||||||
# read the title of this article (outside the revision portion of the document)
|
|
||||||
if not reading_revision:
|
|
||||||
titles = title_regex.search(clean_line)
|
|
||||||
if titles:
|
|
||||||
article_title = titles[0].strip()
|
|
||||||
logger.info("Finished. Processed {} articles".format(article_count))
|
|
||||||
|
|
||||||
|
|
||||||
def _process_wp_text(article_title, article_text, wp_to_id):
|
|
||||||
# ignore meta Wikipedia pages
|
|
||||||
if ns_regex.match(article_title):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# remove the text tags
|
|
||||||
text_search = text_regex.search(article_text)
|
|
||||||
if text_search is None:
|
|
||||||
return None, None
|
|
||||||
text = text_search.group(0)
|
|
||||||
|
|
||||||
# stop processing if this is a redirect page
|
|
||||||
if text.startswith("#REDIRECT"):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# get the raw text without markup etc, keeping only interwiki links
|
|
||||||
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
|
||||||
return clean_text, entities
|
|
||||||
|
|
||||||
|
|
||||||
def _get_clean_wp_text(article_text):
|
|
||||||
clean_text = article_text.strip()
|
|
||||||
|
|
||||||
# remove bolding & italic markup
|
|
||||||
clean_text = clean_text.replace("'''", "")
|
|
||||||
clean_text = clean_text.replace("''", "")
|
|
||||||
|
|
||||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
|
||||||
try_again = True
|
|
||||||
previous_length = len(clean_text)
|
|
||||||
while try_again:
|
|
||||||
clean_text = info_regex.sub(
|
|
||||||
"", clean_text
|
|
||||||
) # non-greedy match excluding a nested {
|
|
||||||
if len(clean_text) < previous_length:
|
|
||||||
try_again = True
|
|
||||||
else:
|
|
||||||
try_again = False
|
|
||||||
previous_length = len(clean_text)
|
|
||||||
|
|
||||||
# remove HTML comments
|
|
||||||
clean_text = html_regex.sub("", clean_text)
|
|
||||||
|
|
||||||
# remove Category and File statements
|
|
||||||
clean_text = category_regex.sub("", clean_text)
|
|
||||||
clean_text = file_regex.sub("", clean_text)
|
|
||||||
|
|
||||||
# remove multiple =
|
|
||||||
while "==" in clean_text:
|
|
||||||
clean_text = clean_text.replace("==", "=")
|
|
||||||
|
|
||||||
clean_text = clean_text.replace(". =", ".")
|
|
||||||
clean_text = clean_text.replace(" = ", ". ")
|
|
||||||
clean_text = clean_text.replace("= ", ".")
|
|
||||||
clean_text = clean_text.replace(" =", "")
|
|
||||||
|
|
||||||
# remove refs (non-greedy match)
|
|
||||||
clean_text = ref_regex.sub("", clean_text)
|
|
||||||
clean_text = ref_2_regex.sub("", clean_text)
|
|
||||||
|
|
||||||
# remove additional wikiformatting
|
|
||||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
|
||||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
|
||||||
|
|
||||||
# change special characters back to normal ones
|
|
||||||
clean_text = clean_text.replace(r"<", "<")
|
|
||||||
clean_text = clean_text.replace(r">", ">")
|
|
||||||
clean_text = clean_text.replace(r""", '"')
|
|
||||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
|
||||||
clean_text = clean_text.replace(r"&", "&")
|
|
||||||
|
|
||||||
# remove multiple spaces
|
|
||||||
while " " in clean_text:
|
|
||||||
clean_text = clean_text.replace(" ", " ")
|
|
||||||
|
|
||||||
return clean_text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_links(clean_text, wp_to_id):
|
|
||||||
# read the text char by char to get the right offsets for the interwiki links
|
|
||||||
entities = []
|
|
||||||
final_text = ""
|
|
||||||
open_read = 0
|
|
||||||
reading_text = True
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = False
|
|
||||||
reading_special_case = False
|
|
||||||
entity_buffer = ""
|
|
||||||
mention_buffer = ""
|
|
||||||
for index, letter in enumerate(clean_text):
|
|
||||||
if letter == "[":
|
|
||||||
open_read += 1
|
|
||||||
elif letter == "]":
|
|
||||||
open_read -= 1
|
|
||||||
elif letter == "|":
|
|
||||||
if reading_text:
|
|
||||||
final_text += letter
|
|
||||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
|
||||||
elif reading_entity:
|
|
||||||
reading_text = False
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = True
|
|
||||||
else:
|
|
||||||
reading_special_case = True
|
|
||||||
else:
|
|
||||||
if reading_entity:
|
|
||||||
entity_buffer += letter
|
|
||||||
elif reading_mention:
|
|
||||||
mention_buffer += letter
|
|
||||||
elif reading_text:
|
|
||||||
final_text += letter
|
|
||||||
else:
|
|
||||||
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
|
|
||||||
|
|
||||||
if open_read > 2:
|
|
||||||
reading_special_case = True
|
|
||||||
|
|
||||||
if open_read == 2 and reading_text:
|
|
||||||
reading_text = False
|
|
||||||
reading_entity = True
|
|
||||||
reading_mention = False
|
|
||||||
|
|
||||||
# we just finished reading an entity
|
|
||||||
if open_read == 0 and not reading_text:
|
|
||||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
|
||||||
reading_special_case = True
|
|
||||||
# Ignore cases with nested structures like File: handles etc
|
|
||||||
if not reading_special_case:
|
|
||||||
if not mention_buffer:
|
|
||||||
mention_buffer = entity_buffer
|
|
||||||
start = len(final_text)
|
|
||||||
end = start + len(mention_buffer)
|
|
||||||
qid = wp_to_id.get(entity_buffer, None)
|
|
||||||
if qid:
|
|
||||||
entities.append((mention_buffer, qid, start, end))
|
|
||||||
final_text += mention_buffer
|
|
||||||
|
|
||||||
entity_buffer = ""
|
|
||||||
mention_buffer = ""
|
|
||||||
|
|
||||||
reading_text = True
|
|
||||||
reading_entity = False
|
|
||||||
reading_mention = False
|
|
||||||
reading_special_case = False
|
|
||||||
return final_text, entities
|
|
||||||
|
|
||||||
|
|
||||||
def _write_training_description(outputfile, qid, description):
|
|
||||||
if description is not None:
|
|
||||||
line = str(qid) + "|" + description + "\n"
|
|
||||||
outputfile.write(line)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
|
||||||
entities_data = [
|
|
||||||
{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
|
|
||||||
for ent in entities
|
|
||||||
]
|
|
||||||
line = (
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"article_id": article_id,
|
|
||||||
"clean_text": clean_text,
|
|
||||||
"entities": entities_data,
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
outputfile.write(line)
|
|
||||||
|
|
||||||
|
|
||||||
def read_training_indices(entity_file_path):
|
|
||||||
""" This method creates two lists of indices into the training file: one with indices for the
|
|
||||||
training examples, and one for the dev examples."""
|
|
||||||
train_indices = []
|
|
||||||
dev_indices = []
|
|
||||||
|
|
||||||
with entity_file_path.open("r", encoding="utf8") as file:
|
|
||||||
for i, line in enumerate(file):
|
|
||||||
example = json.loads(line)
|
|
||||||
article_id = example["article_id"]
|
|
||||||
clean_text = example["clean_text"]
|
|
||||||
|
|
||||||
if is_valid_article(clean_text):
|
|
||||||
if is_dev(article_id):
|
|
||||||
dev_indices.append(i)
|
|
||||||
else:
|
|
||||||
train_indices.append(i)
|
|
||||||
|
|
||||||
return train_indices, dev_indices
|
|
||||||
|
|
||||||
|
|
||||||
def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
|
|
||||||
""" This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
|
|
||||||
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
|
||||||
For testing (kb=None), it will include all positive examples only."""
|
|
||||||
if not labels_discard:
|
|
||||||
labels_discard = []
|
|
||||||
|
|
||||||
texts = []
|
|
||||||
entities_list = []
|
|
||||||
|
|
||||||
with entity_file_path.open("r", encoding="utf8") as file:
|
|
||||||
for i, line in enumerate(file):
|
|
||||||
if i in line_ids:
|
|
||||||
example = json.loads(line)
|
|
||||||
article_id = example["article_id"]
|
|
||||||
clean_text = example["clean_text"]
|
|
||||||
entities = example["entities"]
|
|
||||||
|
|
||||||
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
|
||||||
continue
|
|
||||||
|
|
||||||
texts.append(clean_text)
|
|
||||||
entities_list.append(entities)
|
|
||||||
|
|
||||||
docs = nlp.pipe(texts, batch_size=50)
|
|
||||||
|
|
||||||
for doc, entities in zip(docs, entities_list):
|
|
||||||
gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
|
|
||||||
if gold and len(gold.links) > 0:
|
|
||||||
yield doc, gold
|
|
||||||
|
|
||||||
|
|
||||||
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
|
||||||
gold_entities = {}
|
|
||||||
tagged_ent_positions = {
|
|
||||||
(ent.start_char, ent.end_char): ent
|
|
||||||
for ent in doc.ents
|
|
||||||
if ent.label_ not in labels_discard
|
|
||||||
}
|
|
||||||
|
|
||||||
for entity in entities:
|
|
||||||
entity_id = entity["entity"]
|
|
||||||
alias = entity["alias"]
|
|
||||||
start = entity["start"]
|
|
||||||
end = entity["end"]
|
|
||||||
|
|
||||||
candidate_ids = []
|
|
||||||
if kb and not dev:
|
|
||||||
candidates = kb.get_candidates(alias)
|
|
||||||
candidate_ids = [cand.entity_ for cand in candidates]
|
|
||||||
|
|
||||||
tagged_ent = tagged_ent_positions.get((start, end), None)
|
|
||||||
if tagged_ent:
|
|
||||||
# TODO: check that alias == doc.text[start:end]
|
|
||||||
should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
|
|
||||||
tagged_ent.sent.text
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_add_ent:
|
|
||||||
value_by_id = {entity_id: 1.0}
|
|
||||||
if not dev:
|
|
||||||
random.shuffle(candidate_ids)
|
|
||||||
value_by_id.update(
|
|
||||||
{kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
|
|
||||||
)
|
|
||||||
gold_entities[(start, end)] = value_by_id
|
|
||||||
|
|
||||||
return GoldParse(doc, links=gold_entities)
|
|
||||||
|
|
||||||
|
|
||||||
def is_dev(article_id):
|
|
||||||
if not article_id:
|
|
||||||
return False
|
|
||||||
return article_id.endswith("3")
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_article(doc_text):
|
|
||||||
# custom length cut-off
|
|
||||||
return 10 < len(doc_text) < 30000
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_sentence(sent_text):
|
|
||||||
if not 10 < len(sent_text) < 3000:
|
|
||||||
# custom length cut-off
|
|
||||||
return False
|
|
||||||
|
|
||||||
if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
|
|
||||||
# remove 'enumeration' sentences (occurs often on Wikipedia)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
|
@ -1,15 +1,15 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
|
|
||||||
"""Example of defining and (pre)training spaCy's knowledge base,
|
"""Example of defining a knowledge base in spaCy,
|
||||||
which is needed to implement entity linking functionality.
|
which is needed to implement entity linking functionality.
|
||||||
|
|
||||||
For more details, see the documentation:
|
For more details, see the documentation:
|
||||||
* Knowledge base: https://spacy.io/api/kb
|
* Knowledge base: https://spacy.io/api/kb
|
||||||
* Entity Linking: https://spacy.io/usage/linguistic-features#entity-linking
|
* Entity Linking: https://spacy.io/usage/linguistic-features#entity-linking
|
||||||
|
|
||||||
Compatible with: spaCy v2.2.3
|
Compatible with: spaCy v2.2.4
|
||||||
Last tested with: v2.2.3
|
Last tested with: v2.2.4
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals, print_function
|
from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
|
@ -20,24 +20,18 @@ from spacy.vocab import Vocab
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
|
||||||
|
|
||||||
|
|
||||||
# Q2146908 (Russ Cochran): American golfer
|
# Q2146908 (Russ Cochran): American golfer
|
||||||
# Q7381115 (Russ Cochran): publisher
|
# Q7381115 (Russ Cochran): publisher
|
||||||
ENTITIES = {"Q2146908": ("American golfer", 342), "Q7381115": ("publisher", 17)}
|
ENTITIES = {"Q2146908": ("American golfer", 342), "Q7381115": ("publisher", 17)}
|
||||||
|
|
||||||
INPUT_DIM = 300 # dimension of pretrained input vectors
|
|
||||||
DESC_WIDTH = 64 # dimension of output entity vectors
|
|
||||||
|
|
||||||
|
|
||||||
@plac.annotations(
|
@plac.annotations(
|
||||||
model=("Model name, should have pretrained word embeddings", "positional", None, str),
|
model=("Model name, should have pretrained word embeddings", "positional", None, str),
|
||||||
output_dir=("Optional output directory", "option", "o", Path),
|
output_dir=("Optional output directory", "option", "o", Path),
|
||||||
n_iter=("Number of training iterations", "option", "n", int),
|
|
||||||
)
|
)
|
||||||
def main(model=None, output_dir=None, n_iter=50):
|
def main(model=None, output_dir=None):
|
||||||
"""Load the model, create the KB and pretrain the entity encodings.
|
"""Load the model and create the KB with pre-defined entity encodings.
|
||||||
If an output_dir is provided, the KB will be stored there in a file 'kb'.
|
If an output_dir is provided, the KB will be stored there in a file 'kb'.
|
||||||
The updated vocab will also be written to a directory in the output_dir."""
|
The updated vocab will also be written to a directory in the output_dir."""
|
||||||
|
|
||||||
|
@ -51,33 +45,23 @@ def main(model=None, output_dir=None, n_iter=50):
|
||||||
" cf. https://spacy.io/usage/models#languages."
|
" cf. https://spacy.io/usage/models#languages."
|
||||||
)
|
)
|
||||||
|
|
||||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
# You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
|
||||||
|
# For simplicity, we'll just use the original vector dimension here instead.
|
||||||
|
vectors_dim = nlp.vocab.vectors.shape[1]
|
||||||
|
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=vectors_dim)
|
||||||
|
|
||||||
# set up the data
|
# set up the data
|
||||||
entity_ids = []
|
entity_ids = []
|
||||||
descriptions = []
|
descr_embeddings = []
|
||||||
freqs = []
|
freqs = []
|
||||||
for key, value in ENTITIES.items():
|
for key, value in ENTITIES.items():
|
||||||
desc, freq = value
|
desc, freq = value
|
||||||
entity_ids.append(key)
|
entity_ids.append(key)
|
||||||
descriptions.append(desc)
|
descr_embeddings.append(nlp(desc).vector)
|
||||||
freqs.append(freq)
|
freqs.append(freq)
|
||||||
|
|
||||||
# training entity description encodings
|
|
||||||
# this part can easily be replaced with a custom entity encoder
|
|
||||||
encoder = EntityEncoder(
|
|
||||||
nlp=nlp,
|
|
||||||
input_dim=INPUT_DIM,
|
|
||||||
desc_width=DESC_WIDTH,
|
|
||||||
epochs=n_iter,
|
|
||||||
)
|
|
||||||
encoder.train(description_list=descriptions, to_print=True)
|
|
||||||
|
|
||||||
# get the pretrained entity vectors
|
|
||||||
embeddings = encoder.apply_encoder(descriptions)
|
|
||||||
|
|
||||||
# set the entities, can also be done by calling `kb.add_entity` for each entity
|
# set the entities, can also be done by calling `kb.add_entity` for each entity
|
||||||
kb.set_entities(entity_list=entity_ids, freq_list=freqs, vector_list=embeddings)
|
kb.set_entities(entity_list=entity_ids, freq_list=freqs, vector_list=descr_embeddings)
|
||||||
|
|
||||||
# adding aliases, the entities need to be defined in the KB beforehand
|
# adding aliases, the entities need to be defined in the KB beforehand
|
||||||
kb.add_alias(
|
kb.add_alias(
|
||||||
|
@ -113,8 +97,8 @@ def main(model=None, output_dir=None, n_iter=50):
|
||||||
vocab2 = Vocab().from_disk(vocab_path)
|
vocab2 = Vocab().from_disk(vocab_path)
|
||||||
kb2 = KnowledgeBase(vocab=vocab2)
|
kb2 = KnowledgeBase(vocab=vocab2)
|
||||||
kb2.load_bulk(kb_path)
|
kb2.load_bulk(kb_path)
|
||||||
_print_kb(kb2)
|
|
||||||
print()
|
print()
|
||||||
|
_print_kb(kb2)
|
||||||
|
|
||||||
|
|
||||||
def _print_kb(kb):
|
def _print_kb(kb):
|
||||||
|
@ -126,6 +110,5 @@ if __name__ == "__main__":
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
||||||
# Expected output:
|
# Expected output:
|
||||||
|
|
||||||
# 2 kb entities: ['Q2146908', 'Q7381115']
|
# 2 kb entities: ['Q2146908', 'Q7381115']
|
||||||
# 1 kb aliases: ['Russ Cochran']
|
# 1 kb aliases: ['Russ Cochran']
|
|
@ -1,15 +1,15 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
|
|
||||||
"""Example of training spaCy's entity linker, starting off with an
|
"""Example of training spaCy's entity linker, starting off with a predefined
|
||||||
existing model and a pre-defined knowledge base.
|
knowledge base and corresponding vocab, and a blank English model.
|
||||||
|
|
||||||
For more details, see the documentation:
|
For more details, see the documentation:
|
||||||
* Training: https://spacy.io/usage/training
|
* Training: https://spacy.io/usage/training
|
||||||
* Entity Linking: https://spacy.io/usage/linguistic-features#entity-linking
|
* Entity Linking: https://spacy.io/usage/linguistic-features#entity-linking
|
||||||
|
|
||||||
Compatible with: spaCy v2.2.3
|
Compatible with: spaCy v2.2.4
|
||||||
Last tested with: v2.2.3
|
Last tested with: v2.2.4
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals, print_function
|
from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
|
@ -17,13 +17,11 @@ import plac
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from spacy.symbols import PERSON
|
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
from spacy.pipeline import EntityRuler
|
from spacy.pipeline import EntityRuler
|
||||||
from spacy.tokens import Span
|
|
||||||
from spacy.util import minibatch, compounding
|
from spacy.util import minibatch, compounding
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -111,6 +111,27 @@ start.
|
||||||
https://github.com/explosion/spaCy/tree/master/examples/training/train_new_entity_type.py
|
https://github.com/explosion/spaCy/tree/master/examples/training/train_new_entity_type.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Creating a Knowledge Base for Named Entity Linking {#kb}
|
||||||
|
|
||||||
|
This example shows how to create a knowledge base in spaCy,
|
||||||
|
which is needed to implement entity linking functionality.
|
||||||
|
It requires as input a spaCy model with pretrained word vectors,
|
||||||
|
and it stores the KB to file (if an `output_dir` is provided).
|
||||||
|
|
||||||
|
```python
|
||||||
|
https://github.com/explosion/spaCy/tree/master/examples/training/create_kb.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training spaCy's Named Entity Linker {#nel}
|
||||||
|
|
||||||
|
This example shows how to train spaCy's entity linker with your own custom
|
||||||
|
examples, starting off with a predefined knowledge base and its vocab,
|
||||||
|
and using a blank `English` class.
|
||||||
|
|
||||||
|
```python
|
||||||
|
https://github.com/explosion/spaCy/tree/master/examples/training/train_entity_linker.py
|
||||||
|
```
|
||||||
|
|
||||||
### Training spaCy's Dependency Parser {#parser}
|
### Training spaCy's Dependency Parser {#parser}
|
||||||
|
|
||||||
This example shows how to update spaCy's dependency parser, starting off with an
|
This example shows how to update spaCy's dependency parser, starting off with an
|
||||||
|
|
|
@ -579,9 +579,7 @@ import DisplacyEntHtml from 'images/displacy-ent2.html'
|
||||||
|
|
||||||
To ground the named entities into the "real world", spaCy provides functionality
|
To ground the named entities into the "real world", spaCy provides functionality
|
||||||
to perform entity linking, which resolves a textual entity to a unique
|
to perform entity linking, which resolves a textual entity to a unique
|
||||||
identifier from a knowledge base (KB). The
|
identifier from a knowledge base (KB). You can create your own
|
||||||
[processing scripts](https://github.com/explosion/spaCy/tree/master/bin/wiki_entity_linking)
|
|
||||||
we provide use WikiData identifiers, but you can create your own
|
|
||||||
[`KnowledgeBase`](/api/kb) and
|
[`KnowledgeBase`](/api/kb) and
|
||||||
[train a new Entity Linking model](/usage/training#entity-linker) using that
|
[train a new Entity Linking model](/usage/training#entity-linker) using that
|
||||||
custom-made KB.
|
custom-made KB.
|
||||||
|
|
|
@ -347,9 +347,9 @@ your data** to find a solution that works best for you.
|
||||||
### Updating the Named Entity Recognizer {#example-train-ner}
|
### Updating the Named Entity Recognizer {#example-train-ner}
|
||||||
|
|
||||||
This example shows how to update spaCy's entity recognizer with your own
|
This example shows how to update spaCy's entity recognizer with your own
|
||||||
examples, starting off with an existing, pretrained model, or from scratch
|
examples, starting off with an existing, pretrained model, or from scratch using
|
||||||
using a blank `Language` class. To do this, you'll need **example texts** and
|
a blank `Language` class. To do this, you'll need **example texts** and the
|
||||||
the **character offsets** and **labels** of each entity contained in the texts.
|
**character offsets** and **labels** of each entity contained in the texts.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
https://github.com/explosion/spaCy/tree/master/examples/training/train_ner.py
|
https://github.com/explosion/spaCy/tree/master/examples/training/train_ner.py
|
||||||
|
@ -440,8 +440,8 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_parser.py
|
||||||
training the parser.
|
training the parser.
|
||||||
2. **Add the dependency labels** to the parser using the
|
2. **Add the dependency labels** to the parser using the
|
||||||
[`add_label`](/api/dependencyparser#add_label) method. If you're starting off
|
[`add_label`](/api/dependencyparser#add_label) method. If you're starting off
|
||||||
with a pretrained spaCy model, this is usually not necessary – but it
|
with a pretrained spaCy model, this is usually not necessary – but it doesn't
|
||||||
doesn't hurt either, just to be safe.
|
hurt either, just to be safe.
|
||||||
3. **Shuffle and loop over** the examples. For each example, **update the
|
3. **Shuffle and loop over** the examples. For each example, **update the
|
||||||
model** by calling [`nlp.update`](/api/language#update), which steps through
|
model** by calling [`nlp.update`](/api/language#update), which steps through
|
||||||
the words of the input. At each word, it makes a **prediction**. It then
|
the words of the input. At each word, it makes a **prediction**. It then
|
||||||
|
@ -605,16 +605,16 @@ To train an entity linking model, you first need to define a knowledge base
|
||||||
|
|
||||||
A KB consists of a list of entities with unique identifiers. Each such entity
|
A KB consists of a list of entities with unique identifiers. Each such entity
|
||||||
has an entity vector that will be used to measure similarity with the context in
|
has an entity vector that will be used to measure similarity with the context in
|
||||||
which an entity is used. These vectors are pretrained and stored in the KB
|
which an entity is used. These vectors have a fixed length and are stored in the
|
||||||
before the entity linking model will be trained.
|
KB.
|
||||||
|
|
||||||
The following example shows how to build a knowledge base from scratch, given a
|
The following example shows how to build a knowledge base from scratch, given a
|
||||||
list of entities and potential aliases. The script further demonstrates how to
|
list of entities and potential aliases. The script requires an `nlp` model with
|
||||||
pretrain and store the entity vectors. To run this example, the script needs
|
pretrained word vectors to obtain an encoding of an entity's description as its
|
||||||
access to a `vocab` instance or an `nlp` model with pretrained word embeddings.
|
vector.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
https://github.com/explosion/spaCy/tree/master/examples/training/pretrain_kb.py
|
https://github.com/explosion/spaCy/tree/master/examples/training/create_kb.py
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Step by step guide {#step-by-step-kb}
|
#### Step by step guide {#step-by-step-kb}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user