2019-08-13 16:38:59 +03:00
|
|
|
# coding: utf-8
|
|
|
|
"""Script to take a previously created Knowledge Base and train an entity linking
|
|
|
|
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
|
|
|
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
|
|
|
as created by the script `wikidata_create_kb`.
|
|
|
|
|
|
|
|
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
|
|
|
from https://dumps.wikimedia.org/enwiki/latest/
|
|
|
|
|
|
|
|
"""
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
|
|
import random
|
2019-09-13 18:03:57 +03:00
|
|
|
import logging
|
2019-08-13 16:38:59 +03:00
|
|
|
from pathlib import Path
|
|
|
|
import plac
|
|
|
|
|
|
|
|
from bin.wiki_entity_linking import training_set_creator
|
2019-09-13 18:03:57 +03:00
|
|
|
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
|
|
|
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
|
|
|
|
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
|
2019-08-13 16:38:59 +03:00
|
|
|
|
|
|
|
from spacy.util import minibatch, compounding
|
|
|
|
|
2019-09-13 18:03:57 +03:00
|
|
|
# Module-level logger; every step of the training script reports through it.
logger = logging.getLogger(__name__)
|
2019-08-13 16:38:59 +03:00
|
|
|
|
|
|
|
|
|
|
|
@plac.annotations(
    dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
    output_dir=("Output directory", "option", "o", Path),
    loc_training=("Location to training data", "option", "k", Path),
    epochs=("Number of training iterations (default 10)", "option", "e", int),
    dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
    lr=("Learning rate (default 0.005)", "option", "n", float),
    l2=("L2 regularization", "option", "r", float),
    train_inst=("# training instances (default 90% of all)", "option", "t", int),
    dev_inst=("# test instances (default 10% of all)", "option", "d", int),
)
def main(
    dir_kb,
    output_dir=None,
    loc_training=None,
    epochs=10,
    dropout=0.5,
    lr=0.005,
    l2=1e-6,
    train_inst=None,
    dev_inst=None,
):
    """Train an entity linking pipe on top of a previously created KB.

    The KB directory is read-only input: the KB file, the nlp model and (by
    default) the training data are all loaded from ``dir_kb``. The trained
    pipeline is written under ``output_dir``.

    Args:
        dir_kb: Directory created by ``wikidata_create_kb``. It holds
            ``KB_FILE``, the ``KB_MODEL_DIR`` nlp model and, unless
            ``loc_training`` is given, ``TRAINING_DATA_FILE``.
        output_dir: Where the trained pipeline is written (as
            ``OUTPUT_MODEL_DIR``). Defaults to ``dir_kb``; created if missing.
        loc_training: Explicit location of the training data file.
        epochs: Number of passes over the training data.
        dropout: Dropout rate passed to ``nlp.update``.
        lr: Learning rate assigned to the optimizer.
        l2: L2 regularization assigned to the optimizer.
        train_inst: ``limit`` forwarded to ``read_training`` for the train
            split (presumably ``None`` means no limit -- confirm in
            ``training_set_creator``).
        dev_inst: ``limit`` forwarded to ``read_training`` for the dev split.

    Raises:
        ValueError: If the loaded pipeline has no ``ner`` component.
    """
    logger.info("Creating Entity Linker with Wikipedia and WikiData")

    # Accept plain strings when main() is called from Python, not only via plac.
    dir_kb = Path(dir_kb)
    output_dir = Path(output_dir) if output_dir else dir_kb
    # Inputs come from the KB directory; only the trained model goes to
    # output_dir. Reading the KB from output_dir would fail for a fresh -o dir.
    training_path = Path(loc_training) if loc_training else dir_kb / TRAINING_DATA_FILE
    nlp_dir = dir_kb / KB_MODEL_DIR
    kb_path = dir_kb / KB_FILE
    nlp_output_dir = output_dir / OUTPUT_MODEL_DIR

    # STEP 0: set up IO
    output_dir.mkdir(parents=True, exist_ok=True)

    # STEP 1 : load the NLP object
    logger.info("STEP 1: loading model from {}".format(nlp_dir))
    nlp, kb = read_nlp_kb(nlp_dir, kb_path)

    # check that there is a NER component in the pipeline
    if "ner" not in nlp.pipe_names:
        raise ValueError("The `nlp` object should have a pre-trained `ner` component.")

    # STEP 2: create a training dataset from WP
    logger.info("STEP 2: reading training dataset from {}".format(training_path))

    train_data = training_set_creator.read_training(
        nlp=nlp,
        entity_file_path=training_path,
        dev=False,
        limit=train_inst,
        kb=kb,
    )

    # for testing, get all pos instances, whether or not they are in the kb
    dev_data = training_set_creator.read_training(
        nlp=nlp,
        entity_file_path=training_path,
        dev=True,
        limit=dev_inst,
        kb=kb,
    )

    # STEP 3: create and train the entity linking pipe
    logger.info("STEP 3: training Entity Linking pipe")

    el_pipe = nlp.create_pipe(
        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
    )
    el_pipe.set_kb(kb)
    nlp.add_pipe(el_pipe, last=True)

    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
        optimizer = nlp.begin_training()
        optimizer.learn_rate = lr
        optimizer.L2 = l2

    logger.info("Training on {} articles".format(len(train_data)))
    logger.info("Dev testing on {} articles".format(len(dev_data)))

    dev_baseline_accuracies = measure_baselines(dev_data, kb)

    logger.info("Dev Baseline Accuracies:")
    logger.info(dev_baseline_accuracies.report_accuracy("random"))
    logger.info(dev_baseline_accuracies.report_accuracy("prior"))
    logger.info(dev_baseline_accuracies.report_accuracy("oracle"))

    for itn in range(epochs):
        random.shuffle(train_data)
        losses = {}
        batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
        batchnr = 0

        with nlp.disable_pipes(*other_pipes):
            for batch in batches:
                try:
                    docs, golds = zip(*batch)
                    nlp.update(
                        docs=docs,
                        golds=golds,
                        sgd=optimizer,
                        drop=dropout,
                        losses=losses,
                    )
                    batchnr += 1
                except Exception as e:
                    # Deliberate best effort: a failing batch is logged and
                    # skipped so one bad example does not abort the whole run.
                    logger.error("Error updating batch:" + str(e))

        # Only report when at least one batch succeeded (avoids division by 0).
        if batchnr > 0:
            avg_loss = losses.get("entity_linker", 0.0) / batchnr
            logger.info("Epoch {}, train loss {}".format(itn, round(avg_loss, 2)))
            measure_performance(dev_data, kb, el_pipe)

    # STEP 4: measure the performance of our trained pipe on an independent dev set
    logger.info("STEP 4: performance measurement of Entity Linking pipe")
    measure_performance(dev_data, kb, el_pipe)

    # STEP 5: apply the EL pipe on a toy example
    logger.info("STEP 5: applying Entity Linking to toy example")
    run_el_toy_example(nlp=nlp)

    # STEP 6: write the NLP pipeline (including entity linker) to file.
    # output_dir is always a Path at this point, so this step always runs.
    logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
    nlp.to_disk(nlp_output_dir)

    logger.info("Done!")
|
2019-08-13 16:38:59 +03:00
|
|
|
|
|
|
|
|
|
|
|
def check_kb(kb):
    """Log the KB candidates for a few well-known sample mentions.

    Debugging aid: for each sample alias, every candidate returned by
    ``kb.get_candidates`` is logged on one line as
    ``<prior_prob> <alias> --> <entity> (freq=<entity_freq>)``.
    """
    for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
        candidates = kb.get_candidates(mention)

        logger.info("generating candidates for " + mention + " :")
        for c in candidates:
            # str.join must be called with a list; subscripting it
            # (`" ".join[...]`) raises TypeError.
            logger.info(" ".join([
                str(c.prior_prob),
                c.alias_,
                "-->",
                c.entity_ + " (freq=" + str(c.entity_freq) + ")",
            ]))
|
2019-08-13 16:38:59 +03:00
|
|
|
|
|
|
|
|
|
|
|
def run_el_toy_example(nlp):
    """Run ``nlp`` on a fixed demo text and log every entity it finds.

    The demo text is logged first, then one line per entity in the form
    ``ent <text> <label> <kb_id>``.
    """
    sentences = (
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, ",
        "Douglas reminds us to always bring our towel, even in China or Brazil. ",
        "The main character in Doug's novel is the man Arthur Dent, ",
        "but Dougledydoug doesn't write about George Washington or Homer Simpson.",
    )
    sample = "".join(sentences)
    parsed = nlp(sample)
    logger.info(sample)
    for span in parsed.ents:
        fields = ["ent", span.text, span.label_, span.kb_id_]
        logger.info(" ".join(fields))
|
2019-08-13 16:38:59 +03:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Configure logging first, then let plac parse the CLI and dispatch to main().
    logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
    plac.call(main)
|