Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 10:16:27 +03:00)
Reduce mem usage in training Entity Linker (#4811)
* move nlp processing for el pipe to batch training instead of preprocessing
* adding dev eval back in, and limit in articles instead of entities
* use pipe whenever possible
* few more small doc changes
* access dev data through generator
* tqdm description
* small fixes
* update documentation
This commit is contained in: parent 6e9b61b49d, commit 7b96a5e10f
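The gist of the change: instead of pre-processing the whole corpus into (doc, gold) pairs that stay in memory for every epoch, the training script now keeps only integer line indices into the training file and parses each minibatch on demand. A minimal sketch of the pattern (simplified; `all_examples` and `read_batch` are hypothetical stand-ins for the real reader, `wikipedia_processor.read_el_docs_golds`, shown further down):

    # Before: the entire corpus is parsed up front and held for all epochs.
    train_data = [(nlp(text), gold) for text, gold in all_examples]  # O(corpus) memory

    # After: only line indices are kept; each batch of articles is parsed
    # just-in-time and can be garbage-collected after the update.
    for batch in minibatch(train_indices):
        docs_and_golds = read_batch(batch)  # generator over (doc, gold)
        docs, golds = zip(*docs_and_golds)
        nlp.update(docs=docs, golds=golds)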
bin/wiki_entity_linking/README.md

@@ -7,16 +7,16 @@ Run `wikipedia_pretrain_kb.py`
 * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
 * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
 * You can set the filtering parameters for KB construction:
-  * `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
-  * `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
-  * `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
+  * `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
+  * `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
+  * `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
 * Further parameters to set:
-  * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
-  * `entity_vector_length`: length of the pre-trained entity description vectors
-  * `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
+  * `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
+  * `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
+  * `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)

 Quick testing and rerunning:
-* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
+* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
 * If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.

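For reference, a hypothetical quick-test invocation using the short flags documented above (the positional dump and output arguments are placeholders; check the script's --help output for the exact signature):

    python wikipedia_pretrain_kb.py <wd_json_dump> <wp_xml_dump> <output_dir> -a 10 -f 20 -c 5 -v 64 -la en -lp 1000 -lt 1000 -lw 1000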
@@ -24,11 +24,13 @@ Quick testing and rerunning:

 Run `wikidata_train_entity_linker.py`
 * This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
+* Specify the output directory (`-o`) in which the final, trained model will be saved
 * You can set the learning parameters for the EL training:
-  * `epochs`: number of training iterations
-  * `dropout`: dropout rate
-  * `lr`: learning rate
-  * `l2`: L2 regularization
-* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
+  * `epochs` (`-e`): number of training iterations
+  * `dropout` (`-p`): dropout rate
+  * `lr` (`-n`): learning rate
+  * `l2` (`-r`): L2 regularization
+* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
+  * If not specified, the full dataset will be processed - this may take a LONG time !
 * Further parameters to set:
-  * `labels_discard`: NER label types to discard during training
+  * `labels_discard` (`-l`): NER label types to discard during training

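Likewise, a hypothetical training run with the flags documented above (the KB directory positional and all values are illustrative only):

    python wikidata_train_entity_linker.py <kb_dir> -o <output_dir> -e 10 -p 0.5 -n 0.005 -r 1e-6 -t 5000 -d 500 -l DATE,CARDINAL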
bin/wiki_entity_linking/entity_linker_evaluation.py

@@ -1,6 +1,8 @@
+# coding: utf-8
+from __future__ import unicode_literals

 import logging
 import random

 from tqdm import tqdm
 from collections import defaultdict

@@ -92,133 +94,110 @@ class BaselineResults(object):
         self.random.update_metrics(ent_label, true_entity, random_candidate)


-def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
-    if baseline:
-        baseline_accuracies, counts = measure_baselines(dev_data, kb)
-        logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
-        logger.info(baseline_accuracies.report_performance("random"))
-        logger.info(baseline_accuracies.report_performance("prior"))
-        logger.info(baseline_accuracies.report_performance("oracle"))
-
-    if context:
-        # using only context
-        el_pipe.cfg["incl_context"] = True
-        el_pipe.cfg["incl_prior"] = False
-        results = get_eval_results(dev_data, el_pipe)
-        logger.info(results.report_metrics("context only"))
-
-        # measuring combined accuracy (prior + context)
-        el_pipe.cfg["incl_context"] = True
-        el_pipe.cfg["incl_prior"] = True
-        results = get_eval_results(dev_data, el_pipe)
-        logger.info(results.report_metrics("context and prior"))
-
-
-def get_eval_results(data, el_pipe=None):
-    """
-    Evaluate the ent.kb_id_ annotations against the gold standard.
-    Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
-    If the docs in the data require further processing with an entity linker, set el_pipe.
-    """
-    docs = []
-    golds = []
-    for d, g in tqdm(data, leave=False):
-        if len(d) > 0:
-            golds.append(g)
-            if el_pipe is not None:
-                docs.append(el_pipe(d))
-            else:
-                docs.append(d)
-
-    results = EvaluationResults()
-    for doc, gold in zip(docs, golds):
-        try:
-            correct_entries_per_article = dict()
-            for entity, kb_dict in gold.links.items():
-                start, end = entity
-                for gold_kb, value in kb_dict.items():
-                    if value:
-                        # only evaluating on positive examples
-                        offset = _offset(start, end)
-                        correct_entries_per_article[offset] = gold_kb
-
-            for ent in doc.ents:
-                ent_label = ent.label_
-                pred_entity = ent.kb_id_
-                start = ent.start_char
-                end = ent.end_char
-                offset = _offset(start, end)
-                gold_entity = correct_entries_per_article.get(offset, None)
-                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
-                if gold_entity is not None:
-                    results.update_metrics(ent_label, gold_entity, pred_entity)
-
-        except Exception as e:
-            logging.error("Error assessing accuracy " + str(e))
-
-    return results
-
-
-def measure_baselines(data, kb):
-    """
-    Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
-    Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
-    Also return a dictionary of counts by entity label.
-    """
-    counts_d = dict()
-    baseline_results = BaselineResults()
-
-    docs = [d for d, g in data if len(d) > 0]
-    golds = [g for d, g in data if len(d) > 0]
-
-    for doc, gold in zip(docs, golds):
-        correct_entries_per_article = dict()
-        for entity, kb_dict in gold.links.items():
-            start, end = entity
-            for gold_kb, value in kb_dict.items():
-                # only evaluating on positive examples
-                if value:
-                    offset = _offset(start, end)
-                    correct_entries_per_article[offset] = gold_kb
-
-        for ent in doc.ents:
-            ent_label = ent.label_
-            start = ent.start_char
-            end = ent.end_char
-            offset = _offset(start, end)
-            gold_entity = correct_entries_per_article.get(offset, None)
-
-            # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
-            if gold_entity is not None:
-                candidates = kb.get_candidates(ent.text)
-                oracle_candidate = ""
-                prior_candidate = ""
-                random_candidate = ""
-                if candidates:
-                    scores = []
-
-                    for c in candidates:
-                        scores.append(c.prior_prob)
-                        if c.entity_ == gold_entity:
-                            oracle_candidate = c.entity_
-
-                    best_index = scores.index(max(scores))
-                    prior_candidate = candidates[best_index].entity_
-                    random_candidate = random.choice(candidates).entity_
-
-                current_count = counts_d.get(ent_label, 0)
-                counts_d[ent_label] = current_count+1
-
-                baseline_results.update_baselines(
-                    gold_entity,
-                    ent_label,
-                    random_candidate,
-                    prior_candidate,
-                    oracle_candidate,
-                )
-
-    return baseline_results, counts_d
+def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
+    counts = dict()
+    baseline_results = BaselineResults()
+    context_results = EvaluationResults()
+    combo_results = EvaluationResults()
+
+    for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
+        if len(doc) > 0:
+            correct_ents = dict()
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
+                for gold_kb, value in kb_dict.items():
+                    if value:
+                        # only evaluating on positive examples
+                        offset = _offset(start, end)
+                        correct_ents[offset] = gold_kb
+
+            if baseline:
+                _add_baseline(baseline_results, counts, doc, correct_ents, kb)
+
+            if context:
+                # using only context
+                el_pipe.cfg["incl_context"] = True
+                el_pipe.cfg["incl_prior"] = False
+                _add_eval_result(context_results, doc, correct_ents, el_pipe)
+
+                # measuring combined accuracy (prior + context)
+                el_pipe.cfg["incl_context"] = True
+                el_pipe.cfg["incl_prior"] = True
+                _add_eval_result(combo_results, doc, correct_ents, el_pipe)
+
+    if baseline:
+        logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
+        logger.info(baseline_results.report_performance("random"))
+        logger.info(baseline_results.report_performance("prior"))
+        logger.info(baseline_results.report_performance("oracle"))
+
+    if context:
+        logger.info(context_results.report_metrics("context only"))
+        logger.info(combo_results.report_metrics("context and prior"))
+
+
+def _add_eval_result(results, doc, correct_ents, el_pipe):
+    """
+    Evaluate the ent.kb_id_ annotations against the gold standard.
+    Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
+    """
+    try:
+        doc = el_pipe(doc)
+        for ent in doc.ents:
+            ent_label = ent.label_
+            start = ent.start_char
+            end = ent.end_char
+            offset = _offset(start, end)
+            gold_entity = correct_ents.get(offset, None)
+            # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+            if gold_entity is not None:
+                pred_entity = ent.kb_id_
+                results.update_metrics(ent_label, gold_entity, pred_entity)
+    except Exception as e:
+        logging.error("Error assessing accuracy " + str(e))
+
+
+def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
+    """
+    Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
+    Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
+    """
+    for ent in doc.ents:
+        ent_label = ent.label_
+        start = ent.start_char
+        end = ent.end_char
+        offset = _offset(start, end)
+        gold_entity = correct_ents.get(offset, None)
+
+        # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+        if gold_entity is not None:
+            candidates = kb.get_candidates(ent.text)
+            oracle_candidate = ""
+            prior_candidate = ""
+            random_candidate = ""
+            if candidates:
+                scores = []
+
+                for c in candidates:
+                    scores.append(c.prior_prob)
+                    if c.entity_ == gold_entity:
+                        oracle_candidate = c.entity_
+
+                best_index = scores.index(max(scores))
+                prior_candidate = candidates[best_index].entity_
+                random_candidate = random.choice(candidates).entity_
+
+            current_count = counts.get(ent_label, 0)
+            counts[ent_label] = current_count+1
+
+            baseline_results.update_baselines(
+                gold_entity,
+                ent_label,
+                random_candidate,
+                prior_candidate,
+                oracle_candidate,
+            )


 def _offset(start, end):

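Note that measure_performance now consumes dev_data as a generator in a single pass; callers must recreate the generator before every evaluation, exactly as the training script below does:

    # dev_data is a generator and is exhausted after one pass, so re-read it
    # before each call to measure_performance (dev_limit only sizes the tqdm bar).
    dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
                                                      dev=True, line_ids=dev_indices,
                                                      kb=kb, labels_discard=labels_discard)
    measure_performance(dev_data, kb, el_pipe, dev_limit=len(dev_indices))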
bin/wiki_entity_linking/wikidata_train_entity_linker.py

@@ -1,5 +1,5 @@
 # coding: utf-8
-"""Script to take a previously created Knowledge Base and train an entity linking
+"""Script that takes a previously created Knowledge Base and trains an entity linking
 pipeline. The provided KB directory should hold the kb, the original nlp object and
 its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
 as created by the script `wikidata_create_kb`.

@@ -14,6 +14,7 @@ import logging
 import spacy
 from pathlib import Path
 import plac
+from tqdm import tqdm

 from bin.wiki_entity_linking import wikipedia_processor
 from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR

@@ -33,8 +34,8 @@ logger = logging.getLogger(__name__)
     dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
     lr=("Learning rate (default 0.005)", "option", "n", float),
     l2=("L2 regularization", "option", "r", float),
-    train_inst=("# training instances (default 90% of all)", "option", "t", int),
-    dev_inst=("# test instances (default 10% of all)", "option", "d", int),
+    train_articles=("# training articles (default 90% of all)", "option", "t", int),
+    dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
     labels_discard=("NER labels to discard (default None)", "option", "l", str),
 )
 def main(

@@ -45,10 +46,13 @@ def main(
     dropout=0.5,
     lr=0.005,
     l2=1e-6,
-    train_inst=None,
-    dev_inst=None,
+    train_articles=None,
+    dev_articles=None,
     labels_discard=None
 ):
+    if not output_dir:
+        logger.warning("No output dir specified so no results will be written, are you sure about this ?")
+
     logger.info("Creating Entity Linker with Wikipedia and WikiData")

     output_dir = Path(output_dir) if output_dir else dir_kb

@@ -64,44 +68,33 @@ def main(
     # STEP 1 : load the NLP object
     logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
     nlp = spacy.load(nlp_dir)
-    logger.info("STEP 1b: Loading KB from {}".format(kb_path))
-    kb = read_kb(nlp, kb_path)
+    logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))

     # check that there is a NER component in the pipeline
     if "ner" not in nlp.pipe_names:
         raise ValueError("The `nlp` object should have a pretrained `ner` component.")

-    # STEP 2: read the training dataset previously created from WP
-    logger.info("STEP 2: Reading training dataset from {}".format(training_path))
+    logger.info("STEP 1b: Loading KB from {}".format(kb_path))
+    kb = read_kb(nlp, kb_path)

+    # STEP 2: read the training dataset previously created from WP
+    logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
+    train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
+    logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
+                .format(len(train_indices), train_articles if train_articles else "all"))
+    logger.info("Dev set has {} articles, limit set to roughly {} articles for evaluation"
+                .format(len(dev_indices), dev_articles if dev_articles else "all"))
+    if dev_articles:
+        dev_indices = dev_indices[0:dev_articles]
+
+    # STEP 3: create and train an entity linking pipe
+    logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
     if labels_discard:
         labels_discard = [x.strip() for x in labels_discard.split(",")]
         logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
     else:
         labels_discard = []

-    train_data = wikipedia_processor.read_training(
-        nlp=nlp,
-        entity_file_path=training_path,
-        dev=False,
-        limit=train_inst,
-        kb=kb,
-        labels_discard=labels_discard
-    )
-
-    # for testing, get all pos instances (independently of KB)
-    dev_data = wikipedia_processor.read_training(
-        nlp=nlp,
-        entity_file_path=training_path,
-        dev=True,
-        limit=dev_inst,
-        kb=None,
-        labels_discard=labels_discard
-    )
-
-    # STEP 3: create and train an entity linking pipe
-    logger.info("STEP 3: Creating and training an Entity Linking pipe")
-
     el_pipe = nlp.create_pipe(
         name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
                                       "labels_discard": labels_discard}

@@ -115,80 +108,65 @@ def main(
     optimizer.learn_rate = lr
     optimizer.L2 = l2

-    logger.info("Training on {} articles".format(len(train_data)))
-    logger.info("Dev testing on {} articles".format(len(dev_data)))
-
-    # baseline performance on dev data
     logger.info("Dev Baseline Accuracies:")
-    measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
+    dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
+                                                      dev=True, line_ids=dev_indices,
+                                                      kb=kb, labels_discard=labels_discard)
+
+    measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))

     for itn in range(epochs):
-        random.shuffle(train_data)
+        random.shuffle(train_indices)
         losses = {}
-        batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
+        batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
         batchnr = 0
+        articles_processed = 0

-        with nlp.disable_pipes(*other_pipes):
+        # we either process the whole training file, or just a part each epoch
+        bar_total = len(train_indices)
+        if train_articles:
+            bar_total = train_articles
+
+        with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
             for batch in batches:
-                try:
-                    docs, golds = zip(*batch)
-                    nlp.update(
-                        docs=docs,
-                        golds=golds,
-                        sgd=optimizer,
-                        drop=dropout,
-                        losses=losses,
-                    )
-                    batchnr += 1
-                except Exception as e:
-                    logger.error("Error updating batch:" + str(e))
+                if not train_articles or articles_processed < train_articles:
+                    with nlp.disable_pipes("entity_linker"):
+                        train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
+                                                                             dev=False, line_ids=batch,
+                                                                             kb=kb, labels_discard=labels_discard)
+                        docs, golds = zip(*train_batch)
+                    try:
+                        with nlp.disable_pipes(*other_pipes):
+                            nlp.update(
+                                docs=docs,
+                                golds=golds,
+                                sgd=optimizer,
+                                drop=dropout,
+                                losses=losses,
+                            )
+                        batchnr += 1
+                        articles_processed += len(docs)
+                        pbar.update(len(docs))
+                    except Exception as e:
+                        logger.error("Error updating batch:" + str(e))
         if batchnr > 0:
-            logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
-            measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
-
-    # STEP 4: measure the performance of our trained pipe on an independent dev set
-    logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
-    measure_performance(dev_data, kb, el_pipe)
-
-    # STEP 5: apply the EL pipe on a toy example
-    logger.info("STEP 5: Applying Entity Linking to toy example")
-    run_el_toy_example(nlp=nlp)
+            logging.info("Epoch {} trained on {} articles, train loss {}"
+                         .format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
+            # re-read the dev_data (data is returned as a generator)
+            dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
+                                                              dev=True, line_ids=dev_indices,
+                                                              kb=kb, labels_discard=labels_discard)
+            measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))

     if output_dir:
-        # STEP 6: write the NLP pipeline (now including an EL model) to file
-        logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
+        # STEP 4: write the NLP pipeline (now including an EL model) to file
+        logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
+        logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
         nlp.to_disk(nlp_output_dir)

     logger.info("Done!")


-def check_kb(kb):
-    for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
-        candidates = kb.get_candidates(mention)
-
-        logger.info("generating candidates for " + mention + " :")
-        for c in candidates:
-            logger.info(" ".join([
-                str(c.prior_prob),
-                c.alias_,
-                "-->",
-                c.entity_ + " (freq=" + str(c.entity_freq) + ")"
-            ]))
-
-
-def run_el_toy_example(nlp):
-    text = (
-        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
-        "Douglas reminds us to always bring our towel, even in China or Brazil. "
-        "The main character in Doug's novel is the man Arthur Dent, "
-        "but Dougledydoug doesn't write about George Washington or Homer Simpson."
-    )
-    doc = nlp(text)
-    logger.info(text)
-    for ent in doc.ents:
-        logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
-
-
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
     plac.call(main)

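Aside: minibatch and compounding above are the standard spacy.util helpers; a small sketch of what the new batch schedule does (values taken from the diff, loop body elided):

    from spacy.util import minibatch, compounding

    # Batch sizes start at 8 and grow towards 128, multiplied by 1.001 after
    # each batch, so early updates are small and cheap while later batches
    # amortize overhead over more articles.
    for batch in minibatch(train_indices, size=compounding(8.0, 128.0, 1.001)):
        ...  # each batch is a list of line indices into the training file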
bin/wiki_entity_linking/wikipedia_processor.py

@@ -6,9 +6,6 @@ import bz2
 import logging
 import random
 import json
-from tqdm import tqdm
-
-from functools import partial

 from spacy.gold import GoldParse
 from bin.wiki_entity_linking import wiki_io as io

@@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities):
             outputfile.write(line)


-def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
-    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+def read_training_indices(entity_file_path):
+    """ This method creates two lists of indices into the training file: one with indices for the
+    training examples, and one for the dev examples."""
+    train_indices = []
+    dev_indices = []
+
+    with entity_file_path.open("r", encoding="utf8") as file:
+        for i, line in enumerate(file):
+            example = json.loads(line)
+            article_id = example["article_id"]
+            clean_text = example["clean_text"]
+
+            if is_valid_article(clean_text):
+                if is_dev(article_id):
+                    dev_indices.append(i)
+                else:
+                    train_indices.append(i)
+
+    return train_indices, dev_indices
+
+
+def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
+    """ This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
     For training, it will include both positive and negative examples by using the candidate generator from the kb.
     For testing (kb=None), it will include all positive examples only."""
     if not labels_discard:
         labels_discard = []

-    data = []
-    num_entities = 0
-    get_gold_parse = partial(
-        _get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
-    )
+    texts = []
+    entities_list = []

-    logger.info(
-        "Reading {} data with limit {}".format("dev" if dev else "train", limit)
-    )
     with entity_file_path.open("r", encoding="utf8") as file:
-        with tqdm(total=limit, leave=False) as pbar:
-            for i, line in enumerate(file):
+        for i, line in enumerate(file):
+            if i in line_ids:
                 example = json.loads(line)
                 article_id = example["article_id"]
                 clean_text = example["clean_text"]

@@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
                 if dev != is_dev(article_id) or not is_valid_article(clean_text):
                     continue

-                doc = nlp(clean_text)
-                gold = get_gold_parse(doc, entities)
-                if gold and len(gold.links) > 0:
-                    data.append((doc, gold))
-                    num_entities += len(gold.links)
-                    pbar.update(len(gold.links))
-                if limit and num_entities >= limit:
-                    break
-    logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
-    return data
+                texts.append(clean_text)
+                entities_list.append(entities)
+
+    docs = nlp.pipe(texts, batch_size=50)
+
+    for doc, entities in zip(docs, entities_list):
+        gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
+        if gold and len(gold.links) > 0:
+            yield doc, gold


 def _get_gold_parse(doc, entities, dev, kb, labels_discard):

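The switch from calling nlp(clean_text) per article to nlp.pipe(texts, batch_size=50) is where much of the speed-up comes from; a minimal illustration of the difference (standard spaCy v2 API, variable names are placeholders):

    # One at a time: full per-call pipeline overhead for every text.
    docs = [nlp(text) for text in texts]

    # Batched and lazy: nlp.pipe runs the pipeline over batches of texts and
    # yields Docs as a generator, so throughput goes up and memory stays bounded.
    docs = nlp.pipe(texts, batch_size=50)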
spacy/pipeline/pipes.pyx

@@ -1308,7 +1308,7 @@ class EntityLinker(Pipe):
         for i, doc in enumerate(docs):
             if len(doc) > 0:
                 # Looping through each sentence and each entity
-                # This may go wrong if there are entities across sentences - because they might not get a KB ID
+                # This may go wrong if there are entities across sentences - which shouldn't happen normally.
                 for sent in doc.sents:
                     sent_doc = sent.as_doc()
                     # currently, the context is the same for each entity in a sentence (should be refined)
|
Loading…
Reference in New Issue
Block a user