Reduce mem usage in training Entity Linker (#4811)

* move nlp processing for el pipe to batch training instead of preprocessing

* adding dev eval back in, and limit in articles instead of entities

* use pipe whenever possible

* few more small doc changes

* access dev data through generator

* tqdm description

* small fixes

* update documentation
This commit is contained in:
Sofie Van Landeghem 2020-01-06 14:59:50 +01:00 committed by Matthew Honnibal
parent 6e9b61b49d
commit 7b96a5e10f
5 changed files with 200 additions and 230 deletions

View File

@ -7,16 +7,16 @@ Run `wikipedia_pretrain_kb.py`
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
* You can set the filtering parameters for KB construction:
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
* `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
* `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
* `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
* Further parameters to set:
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
* `entity_vector_length`: length of the pre-trained entity description vectors
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
* `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
* `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
* `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)
Quick testing and rerunning:
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
@ -24,11 +24,13 @@ Quick testing and rerunning:
Run `wikidata_train_entity_linker.py`
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
* Specify the output directory (`-o`) in which the final, trained model will be saved
* You can set the learning parameters for the EL training:
* `epochs`: number of training iterations
* `dropout`: dropout rate
* `lr`: learning rate
* `l2`: L2 regularization
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
* `epochs` (`-e`): number of training iterations
* `dropout` (`-p`): dropout rate
* `lr` (`-n`): learning rate
* `l2` (`-r`): L2 regularization
* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
* If not specified, the full dataset will be processed - this may take a LONG time !
* Further parameters to set:
* `labels_discard`: NER label types to discard during training
* `labels_discard` (`-l`): NER label types to discard during training

View File

@ -1,6 +1,8 @@
# coding: utf-8
from __future__ import unicode_literals
import logging
import random
from tqdm import tqdm
from collections import defaultdict
@ -92,133 +94,110 @@ class BaselineResults(object):
self.random.update_metrics(ent_label, true_entity, random_candidate)
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
if baseline:
baseline_accuracies, counts = measure_baselines(dev_data, kb)
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
logger.info(baseline_accuracies.report_performance("random"))
logger.info(baseline_accuracies.report_performance("prior"))
logger.info(baseline_accuracies.report_performance("oracle"))
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
counts = dict()
baseline_results = BaselineResults()
context_results = EvaluationResults()
combo_results = EvaluationResults()
if context:
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None):
"""
Evaluate the ent.kb_id_ annotations against the gold standard.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
If the docs in the data require further processing with an entity linker, set el_pipe.
"""
docs = []
golds = []
for d, g in tqdm(data, leave=False):
if len(d) > 0:
golds.append(g)
if el_pipe is not None:
docs.append(el_pipe(d))
else:
docs.append(d)
results = EvaluationResults()
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
if len(doc) > 0:
correct_ents = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
if value:
# only evaluating on positive examples
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
correct_ents[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
results.update_metrics(ent_label, gold_entity, pred_entity)
if baseline:
_add_baseline(baseline_results, counts, doc, correct_ents, kb)
except Exception as e:
logging.error("Error assessing accuracy " + str(e))
if context:
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
_add_eval_result(context_results, doc, correct_ents, el_pipe)
return results
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
_add_eval_result(combo_results, doc, correct_ents, el_pipe)
if baseline:
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
logger.info(baseline_results.report_performance("random"))
logger.info(baseline_results.report_performance("prior"))
logger.info(baseline_results.report_performance("oracle"))
if context:
logger.info(context_results.report_metrics("context only"))
logger.info(combo_results.report_metrics("context and prior"))
def measure_baselines(data, kb):
def _add_eval_result(results, doc, correct_ents, el_pipe):
"""
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
Evaluate the ent.kb_id_ annotations against the gold standard.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
Also return a dictionary of counts by entity label.
"""
counts_d = dict()
baseline_results = BaselineResults()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
try:
doc = el_pipe(doc)
for ent in doc.ents:
ent_label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
gold_entity = correct_ents.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
prior_candidate = ""
random_candidate = ""
if candidates:
scores = []
pred_entity = ent.kb_id_
results.update_metrics(ent_label, gold_entity, pred_entity)
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
except Exception as e:
logging.error("Error assessing accuracy " + str(e))
best_index = scores.index(max(scores))
prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
current_count = counts_d.get(ent_label, 0)
counts_d[ent_label] = current_count+1
def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
"""
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
"""
for ent in doc.ents:
ent_label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_ents.get(offset, None)
baseline_results.update_baselines(
gold_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
prior_candidate = ""
random_candidate = ""
if candidates:
scores = []
return baseline_results, counts_d
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
current_count = counts.get(ent_label, 0)
counts[ent_label] = current_count+1
baseline_results.update_baselines(
gold_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
)
def _offset(start, end):

View File

@ -1,5 +1,5 @@
# coding: utf-8
"""Script to take a previously created Knowledge Base and train an entity linking
"""Script that takes a previously created Knowledge Base and trains an entity linking
pipeline. The provided KB directory should hold the kb, the original nlp object and
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
as created by the script `wikidata_create_kb`.
@ -14,6 +14,7 @@ import logging
import spacy
from pathlib import Path
import plac
from tqdm import tqdm
from bin.wiki_entity_linking import wikipedia_processor
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
@ -33,8 +34,8 @@ logger = logging.getLogger(__name__)
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
lr=("Learning rate (default 0.005)", "option", "n", float),
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
train_articles=("# training articles (default 90% of all)", "option", "t", int),
dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
labels_discard=("NER labels to discard (default None)", "option", "l", str),
)
def main(
@ -45,10 +46,13 @@ def main(
dropout=0.5,
lr=0.005,
l2=1e-6,
train_inst=None,
dev_inst=None,
train_articles=None,
dev_articles=None,
labels_discard=None
):
if not output_dir:
logger.warning("No output dir specified so no results will be written, are you sure about this ?")
logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb
@ -64,44 +68,33 @@ def main(
# STEP 1 : load the NLP object
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
nlp = spacy.load(nlp_dir)
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
kb = read_kb(nlp, kb_path)
logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
# STEP 2: read the training dataset previously created from WP
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
kb = read_kb(nlp, kb_path)
# STEP 2: read the training dataset previously created from WP
logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
.format(len(train_indices), train_articles if train_articles else "all"))
logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation"
.format(len(dev_indices), dev_articles if dev_articles else "all"))
if dev_articles:
dev_indices = dev_indices[0:dev_articles]
# STEP 3: create and train an entity linking pipe
logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
if labels_discard:
labels_discard = [x.strip() for x in labels_discard.split(",")]
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
else:
labels_discard = []
train_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
labels_discard=labels_discard
)
# for testing, get all pos instances (independently of KB)
dev_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
kb=None,
labels_discard=labels_discard
)
# STEP 3: create and train an entity linking pipe
logger.info("STEP 3: Creating and training an Entity Linking pipe")
el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
"labels_discard": labels_discard}
@ -115,80 +108,65 @@ def main(
optimizer.learn_rate = lr
optimizer.L2 = l2
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
# baseline performance on dev data
logger.info("Dev Baseline Accuracies:")
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
dev=True, line_ids=dev_indices,
kb=kb, labels_discard=labels_discard)
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))
for itn in range(epochs):
random.shuffle(train_data)
random.shuffle(train_indices)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
batchnr = 0
articles_processed = 0
with nlp.disable_pipes(*other_pipes):
# we either process the whole training file, or just a part each epoch
bar_total = len(train_indices)
if train_articles:
bar_total = train_articles
with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
logger.error("Error updating batch:" + str(e))
if not train_articles or articles_processed < train_articles:
with nlp.disable_pipes("entity_linker"):
train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
dev=False, line_ids=batch,
kb=kb, labels_discard=labels_discard)
docs, golds = zip(*train_batch)
try:
with nlp.disable_pipes(*other_pipes):
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
articles_processed += len(docs)
pbar.update(len(docs))
except Exception as e:
logger.error("Error updating batch:" + str(e))
if batchnr > 0:
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
# STEP 4: measure the performance of our trained pipe on an independent dev set
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: Applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
logging.info("Epoch {} trained on {} articles, train loss {}"
.format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
# re-read the dev_data (data is returned as a generator)
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
dev=True, line_ids=dev_indices,
kb=kb, labels_discard=labels_discard)
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))
if output_dir:
# STEP 6: write the NLP pipeline (now including an EL model) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
# STEP 4: write the NLP pipeline (now including an EL model) to file
logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir)
logger.info("Done!")
def check_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention)
logger.info("generating candidates for " + mention + " :")
for c in candidates:
logger.info(" ".join[
str(c.prior_prob),
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
])
def run_el_toy_example(nlp):
text = (
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, "
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
logger.info(text)
for ent in doc.ents:
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)

View File

@ -6,9 +6,6 @@ import bz2
import logging
import random
import json
from tqdm import tqdm
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import wiki_io as io
@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities):
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
def read_training_indices(entity_file_path):
""" This method creates two lists of indices into the training file: one with indices for the
training examples, and one for the dev examples."""
train_indices = []
dev_indices = []
with entity_file_path.open("r", encoding="utf8") as file:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
if is_valid_article(clean_text):
if is_dev(article_id):
dev_indices.append(i)
else:
train_indices.append(i)
return train_indices, dev_indices
def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
""" This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
For training, it will include both positive and negative examples by using the candidate generator from the kb.
For testing (kb=None), it will include all positive examples only."""
if not labels_discard:
labels_discard = []
data = []
num_entities = 0
get_gold_parse = partial(
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
)
texts = []
entities_list = []
logger.info(
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
)
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
for i, line in enumerate(file):
if i in line_ids:
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
if dev != is_dev(article_id) or not is_valid_article(clean_text):
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
texts.append(clean_text)
entities_list.append(entities)
docs = nlp.pipe(texts, batch_size=50)
for doc, entities in zip(docs, entities_list):
gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
if gold and len(gold.links) > 0:
yield doc, gold
def _get_gold_parse(doc, entities, dev, kb, labels_discard):

View File

@ -1308,7 +1308,7 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs):
if len(doc) > 0:
# Looping through each sentence and each entity
# This may go wrong if there are entities across sentences - because they might not get a KB ID
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
for sent in doc.sents:
sent_doc = sent.as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)