Reduce mem usage in training Entity Linker (#4811)

* move nlp processing for el pipe to batch training instead of preprocessing

* adding dev eval back in, and limit in articles instead of entities

* use pipe whenever possible

* few more small doc changes

* access dev data through generator

* tqdm description

* small fixes

* update documentation
This commit is contained in:
Sofie Van Landeghem 2020-01-06 14:59:50 +01:00 committed by Matthew Honnibal
parent 6e9b61b49d
commit 7b96a5e10f
5 changed files with 200 additions and 230 deletions

View File

@ -7,16 +7,16 @@ Run `wikipedia_pretrain_kb.py`
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/ * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language) * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
* You can set the filtering parameters for KB construction: * You can set the filtering parameters for KB construction:
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym * `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB * `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB * `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
* Further parameters to set: * Further parameters to set:
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`) * `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
* `entity_vector_length`: length of the pre-trained entity description vectors * `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages) * `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)
Quick testing and rerunning: Quick testing and rerunning:
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything. * When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed. * If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
@ -24,11 +24,13 @@ Quick testing and rerunning:
Run `wikidata_train_entity_linker.py` Run `wikidata_train_entity_linker.py`
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** * This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
* Specify the output directory (`-o`) in which the final, trained model will be saved
* You can set the learning parameters for the EL training: * You can set the learning parameters for the EL training:
* `epochs`: number of training iterations * `epochs` (`-e`): number of training iterations
* `dropout`: dropout rate * `dropout` (`-p`): dropout rate
* `lr`: learning rate * `lr` (`-n`): learning rate
* `l2`: L2 regularization * `l2` (`-r`): L2 regularization
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively * Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
* If not specified, the full dataset will be processed - this may take a LONG time!
* Further parameters to set: * Further parameters to set:
* `labels_discard`: NER label types to discard during training * `labels_discard` (`-l`): NER label types to discard during training

View File

@ -1,6 +1,8 @@
# coding: utf-8
from __future__ import unicode_literals
import logging import logging
import random import random
from tqdm import tqdm from tqdm import tqdm
from collections import defaultdict from collections import defaultdict
@ -92,133 +94,110 @@ class BaselineResults(object):
self.random.update_metrics(ent_label, true_entity, random_candidate) self.random.update_metrics(ent_label, true_entity, random_candidate)
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True): def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
if baseline: counts = dict()
baseline_accuracies, counts = measure_baselines(dev_data, kb) baseline_results = BaselineResults()
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())})) context_results = EvaluationResults()
logger.info(baseline_accuracies.report_performance("random")) combo_results = EvaluationResults()
logger.info(baseline_accuracies.report_performance("prior"))
logger.info(baseline_accuracies.report_performance("oracle"))
if context: for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
# using only context if len(doc) > 0:
el_pipe.cfg["incl_context"] = True correct_ents = dict()
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None):
"""
Evaluate the ent.kb_id_ annotations against the gold standard.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
If the docs in the data require further processing with an entity linker, set el_pipe.
"""
docs = []
golds = []
for d, g in tqdm(data, leave=False):
if len(d) > 0:
golds.append(g)
if el_pipe is not None:
docs.append(el_pipe(d))
else:
docs.append(d)
results = EvaluationResults()
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end = entity start, end = entity
for gold_kb, value in kb_dict.items(): for gold_kb, value in kb_dict.items():
if value: if value:
# only evaluating on positive examples # only evaluating on positive examples
offset = _offset(start, end) offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb correct_ents[offset] = gold_kb
for ent in doc.ents: if baseline:
ent_label = ent.label_ _add_baseline(baseline_results, counts, doc, correct_ents, kb)
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
results.update_metrics(ent_label, gold_entity, pred_entity)
except Exception as e: if context:
logging.error("Error assessing accuracy " + str(e)) # using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
_add_eval_result(context_results, doc, correct_ents, el_pipe)
return results # measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
_add_eval_result(combo_results, doc, correct_ents, el_pipe)
if baseline:
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
logger.info(baseline_results.report_performance("random"))
logger.info(baseline_results.report_performance("prior"))
logger.info(baseline_results.report_performance("oracle"))
if context:
logger.info(context_results.report_metrics("context only"))
logger.info(combo_results.report_metrics("context and prior"))
def measure_baselines(data, kb): def _add_eval_result(results, doc, correct_ents, el_pipe):
""" """
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound. Evaluate the ent.kb_id_ annotations against the gold standard.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
Also return a dictionary of counts by entity label.
""" """
counts_d = dict() try:
doc = el_pipe(doc)
baseline_results = BaselineResults()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ ent_label = ent.label_
start = ent.start_char start = ent.start_char
end = ent.end_char end = ent.end_char
offset = _offset(start, end) offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None) gold_entity = correct_ents.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
candidates = kb.get_candidates(ent.text) pred_entity = ent.kb_id_
oracle_candidate = "" results.update_metrics(ent_label, gold_entity, pred_entity)
prior_candidate = ""
random_candidate = ""
if candidates:
scores = []
for c in candidates: except Exception as e:
scores.append(c.prior_prob) logging.error("Error assessing accuracy " + str(e))
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
current_count = counts_d.get(ent_label, 0) def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
counts_d[ent_label] = current_count+1 """
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
"""
for ent in doc.ents:
ent_label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_ents.get(offset, None)
baseline_results.update_baselines( # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
gold_entity, if gold_entity is not None:
ent_label, candidates = kb.get_candidates(ent.text)
random_candidate, oracle_candidate = ""
prior_candidate, prior_candidate = ""
oracle_candidate, random_candidate = ""
) if candidates:
scores = []
return baseline_results, counts_d for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
current_count = counts.get(ent_label, 0)
counts[ent_label] = current_count+1
baseline_results.update_baselines(
gold_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
)
def _offset(start, end): def _offset(start, end):

View File

@ -1,5 +1,5 @@
# coding: utf-8 # coding: utf-8
"""Script to take a previously created Knowledge Base and train an entity linking """Script that takes a previously created Knowledge Base and trains an entity linking
pipeline. The provided KB directory should hold the kb, the original nlp object and pipeline. The provided KB directory should hold the kb, the original nlp object and
its vocab used to create the KB, and a few auxiliary files such as the entity definitions, its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
as created by the script `wikidata_create_kb`. as created by the script `wikidata_create_kb`.
@ -14,6 +14,7 @@ import logging
import spacy import spacy
from pathlib import Path from pathlib import Path
import plac import plac
from tqdm import tqdm
from bin.wiki_entity_linking import wikipedia_processor from bin.wiki_entity_linking import wikipedia_processor
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
@ -33,8 +34,8 @@ logger = logging.getLogger(__name__)
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float), dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
lr=("Learning rate (default 0.005)", "option", "n", float), lr=("Learning rate (default 0.005)", "option", "n", float),
l2=("L2 regularization", "option", "r", float), l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int), train_articles=("# training articles (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int), dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
labels_discard=("NER labels to discard (default None)", "option", "l", str), labels_discard=("NER labels to discard (default None)", "option", "l", str),
) )
def main( def main(
@ -45,10 +46,13 @@ def main(
dropout=0.5, dropout=0.5,
lr=0.005, lr=0.005,
l2=1e-6, l2=1e-6,
train_inst=None, train_articles=None,
dev_inst=None, dev_articles=None,
labels_discard=None labels_discard=None
): ):
if not output_dir:
logger.warning("No output dir specified so no results will be written, are you sure about this ?")
logger.info("Creating Entity Linker with Wikipedia and WikiData") logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb output_dir = Path(output_dir) if output_dir else dir_kb
@ -64,44 +68,33 @@ def main(
# STEP 1 : load the NLP object # STEP 1 : load the NLP object
logger.info("STEP 1a: Loading model from {}".format(nlp_dir)) logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
nlp = spacy.load(nlp_dir) nlp = spacy.load(nlp_dir)
logger.info("STEP 1b: Loading KB from {}".format(kb_path)) logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
kb = read_kb(nlp, kb_path)
# check that there is a NER component in the pipeline # check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names: if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pretrained `ner` component.") raise ValueError("The `nlp` object should have a pretrained `ner` component.")
# STEP 2: read the training dataset previously created from WP logger.info("STEP 1b: Loading KB from {}".format(kb_path))
logger.info("STEP 2: Reading training dataset from {}".format(training_path)) kb = read_kb(nlp, kb_path)
# STEP 2: read the training dataset previously created from WP
logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
.format(len(train_indices), train_articles if train_articles else "all"))
logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation"
.format(len(dev_indices), dev_articles if dev_articles else "all"))
if dev_articles:
dev_indices = dev_indices[0:dev_articles]
# STEP 3: create and train an entity linking pipe
logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
if labels_discard: if labels_discard:
labels_discard = [x.strip() for x in labels_discard.split(",")] labels_discard = [x.strip() for x in labels_discard.split(",")]
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard)) logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
else: else:
labels_discard = [] labels_discard = []
train_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
labels_discard=labels_discard
)
# for testing, get all pos instances (independently of KB)
dev_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
kb=None,
labels_discard=labels_discard
)
# STEP 3: create and train an entity linking pipe
logger.info("STEP 3: Creating and training an Entity Linking pipe")
el_pipe = nlp.create_pipe( el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name, name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
"labels_discard": labels_discard} "labels_discard": labels_discard}
@ -115,80 +108,65 @@ def main(
optimizer.learn_rate = lr optimizer.learn_rate = lr
optimizer.L2 = l2 optimizer.L2 = l2
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
# baseline performance on dev data
logger.info("Dev Baseline Accuracies:") logger.info("Dev Baseline Accuracies:")
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False) dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
dev=True, line_ids=dev_indices,
kb=kb, labels_discard=labels_discard)
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))
for itn in range(epochs): for itn in range(epochs):
random.shuffle(train_data) random.shuffle(train_indices)
losses = {} losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001)) batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
batchnr = 0 batchnr = 0
articles_processed = 0
with nlp.disable_pipes(*other_pipes): # we either process the whole training file, or just a part each epoch
bar_total = len(train_indices)
if train_articles:
bar_total = train_articles
with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
for batch in batches: for batch in batches:
try: if not train_articles or articles_processed < train_articles:
docs, golds = zip(*batch) with nlp.disable_pipes("entity_linker"):
nlp.update( train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
docs=docs, dev=False, line_ids=batch,
golds=golds, kb=kb, labels_discard=labels_discard)
sgd=optimizer, docs, golds = zip(*train_batch)
drop=dropout, try:
losses=losses, with nlp.disable_pipes(*other_pipes):
) nlp.update(
batchnr += 1 docs=docs,
except Exception as e: golds=golds,
logger.error("Error updating batch:" + str(e)) sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
articles_processed += len(docs)
pbar.update(len(docs))
except Exception as e:
logger.error("Error updating batch:" + str(e))
if batchnr > 0: if batchnr > 0:
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2))) logging.info("Epoch {} trained on {} articles, train loss {}"
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True) .format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
# re-read the dev_data (data is returned as a generator)
# STEP 4: measure the performance of our trained pipe on an independent dev set dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
logger.info("STEP 4: Final performance measurement of Entity Linking pipe") dev=True, line_ids=dev_indices,
measure_performance(dev_data, kb, el_pipe) kb=kb, labels_discard=labels_discard)
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))
# STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: Applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
if output_dir: if output_dir:
# STEP 6: write the NLP pipeline (now including an EL model) to file # STEP 4: write the NLP pipeline (now including an EL model) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir)) logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir) nlp.to_disk(nlp_output_dir)
logger.info("Done!") logger.info("Done!")
def check_kb(kb):
    """Log the candidates the knowledge base generates for a few sample mentions.

    For every hard-coded mention, `kb.get_candidates` is queried and each
    returned candidate is logged on its own line, showing its prior
    probability, alias, entity ID and entity frequency.

    kb: the knowledge base object to query; it must expose `get_candidates`.
    """
    for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
        candidates = kb.get_candidates(mention)
        logger.info("generating candidates for " + mention + " :")
        for c in candidates:
            # str.join has to be called with the list of parts; subscripting
            # the method (`" ".join[...]`) raises a TypeError at runtime.
            logger.info(" ".join([
                str(c.prior_prob),
                c.alias_,
                "-->",
                c.entity_ + " (freq=" + str(c.entity_freq) + ")"
            ]))
def run_el_toy_example(nlp):
    """Apply the given pipeline to a small toy text and log what it finds.

    The raw text is logged first, followed by one line per entity in
    `doc.ents` containing the entity text, its label and its KB identifier.

    nlp: a callable pipeline that turns a text into a doc with `ents`.
    """
    fragments = [
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, ",
        "Douglas reminds us to always bring our towel, even in China or Brazil. ",
        "The main character in Doug's novel is the man Arthur Dent, ",
        "but Dougledydoug doesn't write about George Washington or Homer Simpson.",
    ]
    text = "".join(fragments)
    parsed = nlp(text)
    logger.info(text)
    for entity in parsed.ents:
        fields = ["ent", entity.text, entity.label_, entity.kb_id_]
        logger.info(" ".join(fields))
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main) plac.call(main)

View File

@ -6,9 +6,6 @@ import bz2
import logging import logging
import random import random
import json import json
from tqdm import tqdm
from functools import partial
from spacy.gold import GoldParse from spacy.gold import GoldParse
from bin.wiki_entity_linking import wiki_io as io from bin.wiki_entity_linking import wiki_io as io
@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities):
outputfile.write(line) outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None): def read_training_indices(entity_file_path):
""" This method provides training examples that correspond to the entity annotations found by the nlp object. """ This method creates two lists of indices into the training file: one with indices for the
training examples, and one for the dev examples."""
train_indices = []
dev_indices = []
with entity_file_path.open("r", encoding="utf8") as file:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
if is_valid_article(clean_text):
if is_dev(article_id):
dev_indices.append(i)
else:
train_indices.append(i)
return train_indices, dev_indices
def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
""" This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
For training, it will include both positive and negative examples by using the candidate generator from the kb. For training, it will include both positive and negative examples by using the candidate generator from the kb.
For testing (kb=None), it will include all positive examples only.""" For testing (kb=None), it will include all positive examples only."""
if not labels_discard: if not labels_discard:
labels_discard = [] labels_discard = []
data = [] texts = []
num_entities = 0 entities_list = []
get_gold_parse = partial(
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
)
logger.info(
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
)
with entity_file_path.open("r", encoding="utf8") as file: with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar: for i, line in enumerate(file):
for i, line in enumerate(file): if i in line_ids:
example = json.loads(line) example = json.loads(line)
article_id = example["article_id"] article_id = example["article_id"]
clean_text = example["clean_text"] clean_text = example["clean_text"]
@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
if dev != is_dev(article_id) or not is_valid_article(clean_text): if dev != is_dev(article_id) or not is_valid_article(clean_text):
continue continue
doc = nlp(clean_text) texts.append(clean_text)
gold = get_gold_parse(doc, entities) entities_list.append(entities)
if gold and len(gold.links) > 0:
data.append((doc, gold)) docs = nlp.pipe(texts, batch_size=50)
num_entities += len(gold.links)
pbar.update(len(gold.links)) for doc, entities in zip(docs, entities_list):
if limit and num_entities >= limit: gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
break if gold and len(gold.links) > 0:
logger.info("Read {} entities in {} articles".format(num_entities, len(data))) yield doc, gold
return data
def _get_gold_parse(doc, entities, dev, kb, labels_discard): def _get_gold_parse(doc, entities, dev, kb, labels_discard):

View File

@ -1308,7 +1308,7 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if len(doc) > 0: if len(doc) > 0:
# Looping through each sentence and each entity # Looping through each sentence and each entity
# This may go wrong if there are entities across sentences - because they might not get a KB ID # This may go wrong if there are entities across sentences - which shouldn't happen normally.
for sent in doc.sents: for sent in doc.sents:
sent_doc = sent.as_doc() sent_doc = sent.as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)