mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 01:34:30 +03:00
Reduce mem usage in training Entity Linker (#4811)
* move nlp processing for el pipe to batch training instead of preprocessing * adding dev eval back in, and limit in articles instead of entities * use pipe whenever possible * few more small doc changes * access dev data through generator * tqdm description * small fixes * update documentation
This commit is contained in:
parent
6e9b61b49d
commit
7b96a5e10f
|
@ -7,16 +7,16 @@ Run `wikipedia_pretrain_kb.py`
|
|||
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
|
||||
* You can set the filtering parameters for KB construction:
|
||||
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
|
||||
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
|
||||
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
||||
* `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym
|
||||
* `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB
|
||||
* `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
||||
* Further parameters to set:
|
||||
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
||||
* `entity_vector_length`: length of the pre-trained entity description vectors
|
||||
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
|
||||
* `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
||||
* `entity_vector_length` (`-v`): length of the pre-trained entity description vectors
|
||||
* `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages)
|
||||
|
||||
Quick testing and rerunning:
|
||||
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
|
||||
* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything.
|
||||
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
|
||||
|
||||
|
||||
|
@ -24,11 +24,13 @@ Quick testing and rerunning:
|
|||
|
||||
Run `wikidata_train_entity_linker.py`
|
||||
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
|
||||
* Specify the output directory (`-o`) in which the final, trained model will be saved
|
||||
* You can set the learning parameters for the EL training:
|
||||
* `epochs`: number of training iterations
|
||||
* `dropout`: dropout rate
|
||||
* `lr`: learning rate
|
||||
* `l2`: L2 regularization
|
||||
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
|
||||
* `epochs` (`-e`): number of training iterations
|
||||
* `dropout` (`-p`): dropout rate
|
||||
* `lr` (`-n`): learning rate
|
||||
* `l2` (`-r`): L2 regularization
|
||||
* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively
|
||||
* If not specified, the full dataset will be processed - this may take a LONG time !
|
||||
* Further parameters to set:
|
||||
* `labels_discard`: NER label types to discard during training
|
||||
* `labels_discard` (`-l`): NER label types to discard during training
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
|
||||
|
@ -92,133 +94,110 @@ class BaselineResults(object):
|
|||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||
|
||||
|
||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
|
||||
if baseline:
|
||||
baseline_accuracies, counts = measure_baselines(dev_data, kb)
|
||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||
logger.info(baseline_accuracies.report_performance("random"))
|
||||
logger.info(baseline_accuracies.report_performance("prior"))
|
||||
logger.info(baseline_accuracies.report_performance("oracle"))
|
||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None):
|
||||
counts = dict()
|
||||
baseline_results = BaselineResults()
|
||||
context_results = EvaluationResults()
|
||||
combo_results = EvaluationResults()
|
||||
|
||||
if context:
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context only"))
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context and prior"))
|
||||
|
||||
|
||||
def get_eval_results(data, el_pipe=None):
|
||||
"""
|
||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
If the docs in the data require further processing with an entity linker, set el_pipe.
|
||||
"""
|
||||
docs = []
|
||||
golds = []
|
||||
for d, g in tqdm(data, leave=False):
|
||||
if len(d) > 0:
|
||||
golds.append(g)
|
||||
if el_pipe is not None:
|
||||
docs.append(el_pipe(d))
|
||||
else:
|
||||
docs.append(d)
|
||||
|
||||
results = EvaluationResults()
|
||||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'):
|
||||
if len(doc) > 0:
|
||||
correct_ents = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
# only evaluating on positive examples
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
correct_ents[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
pred_entity = ent.kb_id_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
results.update_metrics(ent_label, gold_entity, pred_entity)
|
||||
if baseline:
|
||||
_add_baseline(baseline_results, counts, doc, correct_ents, kb)
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Error assessing accuracy " + str(e))
|
||||
if context:
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
_add_eval_result(context_results, doc, correct_ents, el_pipe)
|
||||
|
||||
return results
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
_add_eval_result(combo_results, doc, correct_ents, el_pipe)
|
||||
|
||||
if baseline:
|
||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||
logger.info(baseline_results.report_performance("random"))
|
||||
logger.info(baseline_results.report_performance("prior"))
|
||||
logger.info(baseline_results.report_performance("oracle"))
|
||||
|
||||
if context:
|
||||
logger.info(context_results.report_metrics("context only"))
|
||||
logger.info(combo_results.report_metrics("context and prior"))
|
||||
|
||||
|
||||
def measure_baselines(data, kb):
|
||||
def _add_eval_result(results, doc, correct_ents, el_pipe):
|
||||
"""
|
||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
Also return a dictionary of counts by entity label.
|
||||
"""
|
||||
counts_d = dict()
|
||||
|
||||
baseline_results = BaselineResults()
|
||||
|
||||
docs = [d for d, g in data if len(d) > 0]
|
||||
golds = [g for d, g in data if len(d) > 0]
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
try:
|
||||
doc = el_pipe(doc)
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
|
||||
gold_entity = correct_ents.get(offset, None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
prior_candidate = ""
|
||||
random_candidate = ""
|
||||
if candidates:
|
||||
scores = []
|
||||
pred_entity = ent.kb_id_
|
||||
results.update_metrics(ent_label, gold_entity, pred_entity)
|
||||
|
||||
for c in candidates:
|
||||
scores.append(c.prior_prob)
|
||||
if c.entity_ == gold_entity:
|
||||
oracle_candidate = c.entity_
|
||||
except Exception as e:
|
||||
logging.error("Error assessing accuracy " + str(e))
|
||||
|
||||
best_index = scores.index(max(scores))
|
||||
prior_candidate = candidates[best_index].entity_
|
||||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
current_count = counts_d.get(ent_label, 0)
|
||||
counts_d[ent_label] = current_count+1
|
||||
def _add_baseline(baseline_results, counts, doc, correct_ents, kb):
|
||||
"""
|
||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
"""
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_ents.get(offset, None)
|
||||
|
||||
baseline_results.update_baselines(
|
||||
gold_entity,
|
||||
ent_label,
|
||||
random_candidate,
|
||||
prior_candidate,
|
||||
oracle_candidate,
|
||||
)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
prior_candidate = ""
|
||||
random_candidate = ""
|
||||
if candidates:
|
||||
scores = []
|
||||
|
||||
return baseline_results, counts_d
|
||||
for c in candidates:
|
||||
scores.append(c.prior_prob)
|
||||
if c.entity_ == gold_entity:
|
||||
oracle_candidate = c.entity_
|
||||
|
||||
best_index = scores.index(max(scores))
|
||||
prior_candidate = candidates[best_index].entity_
|
||||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
current_count = counts.get(ent_label, 0)
|
||||
counts[ent_label] = current_count+1
|
||||
|
||||
baseline_results.update_baselines(
|
||||
gold_entity,
|
||||
ent_label,
|
||||
random_candidate,
|
||||
prior_candidate,
|
||||
oracle_candidate,
|
||||
)
|
||||
|
||||
|
||||
def _offset(start, end):
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding: utf-8
|
||||
"""Script to take a previously created Knowledge Base and train an entity linking
|
||||
"""Script that takes a previously created Knowledge Base and trains an entity linking
|
||||
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
||||
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
||||
as created by the script `wikidata_create_kb`.
|
||||
|
@ -14,6 +14,7 @@ import logging
|
|||
import spacy
|
||||
from pathlib import Path
|
||||
import plac
|
||||
from tqdm import tqdm
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||
|
@ -33,8 +34,8 @@ logger = logging.getLogger(__name__)
|
|||
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
||||
lr=("Learning rate (default 0.005)", "option", "n", float),
|
||||
l2=("L2 regularization", "option", "r", float),
|
||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||
train_articles=("# training articles (default 90% of all)", "option", "t", int),
|
||||
dev_articles=("# dev test articles (default 10% of all)", "option", "d", int),
|
||||
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
||||
)
|
||||
def main(
|
||||
|
@ -45,10 +46,13 @@ def main(
|
|||
dropout=0.5,
|
||||
lr=0.005,
|
||||
l2=1e-6,
|
||||
train_inst=None,
|
||||
dev_inst=None,
|
||||
train_articles=None,
|
||||
dev_articles=None,
|
||||
labels_discard=None
|
||||
):
|
||||
if not output_dir:
|
||||
logger.warning("No output dir specified so no results will be written, are you sure about this ?")
|
||||
|
||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||
|
||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||
|
@ -64,44 +68,33 @@ def main(
|
|||
# STEP 1 : load the NLP object
|
||||
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
||||
nlp = spacy.load(nlp_dir)
|
||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||
kb = read_kb(nlp, kb_path)
|
||||
logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
||||
|
||||
# check that there is a NER component in the pipeline
|
||||
if "ner" not in nlp.pipe_names:
|
||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
||||
|
||||
# STEP 2: read the training dataset previously created from WP
|
||||
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
|
||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||
kb = read_kb(nlp, kb_path)
|
||||
|
||||
# STEP 2: read the training dataset previously created from WP
|
||||
logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path))
|
||||
train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path)
|
||||
logger.info("Training set has {} articles, limit set to roughly {} articles per epoch"
|
||||
.format(len(train_indices), train_articles if train_articles else "all"))
|
||||
logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation"
|
||||
.format(len(dev_indices), dev_articles if dev_articles else "all"))
|
||||
if dev_articles:
|
||||
dev_indices = dev_indices[0:dev_articles]
|
||||
|
||||
# STEP 3: create and train an entity linking pipe
|
||||
logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs))
|
||||
if labels_discard:
|
||||
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
||||
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
||||
else:
|
||||
labels_discard = []
|
||||
|
||||
train_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=False,
|
||||
limit=train_inst,
|
||||
kb=kb,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# for testing, get all pos instances (independently of KB)
|
||||
dev_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=True,
|
||||
limit=dev_inst,
|
||||
kb=None,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# STEP 3: create and train an entity linking pipe
|
||||
logger.info("STEP 3: Creating and training an Entity Linking pipe")
|
||||
|
||||
el_pipe = nlp.create_pipe(
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
||||
"labels_discard": labels_discard}
|
||||
|
@ -115,80 +108,65 @@ def main(
|
|||
optimizer.learn_rate = lr
|
||||
optimizer.L2 = l2
|
||||
|
||||
logger.info("Training on {} articles".format(len(train_data)))
|
||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||
|
||||
# baseline performance on dev data
|
||||
logger.info("Dev Baseline Accuracies:")
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
|
||||
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||
dev=True, line_ids=dev_indices,
|
||||
kb=kb, labels_discard=labels_discard)
|
||||
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices))
|
||||
|
||||
for itn in range(epochs):
|
||||
random.shuffle(train_data)
|
||||
random.shuffle(train_indices)
|
||||
losses = {}
|
||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||
batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001))
|
||||
batchnr = 0
|
||||
articles_processed = 0
|
||||
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
# we either process the whole training file, or just a part each epoch
|
||||
bar_total = len(train_indices)
|
||||
if train_articles:
|
||||
bar_total = train_articles
|
||||
|
||||
with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar:
|
||||
for batch in batches:
|
||||
try:
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
sgd=optimizer,
|
||||
drop=dropout,
|
||||
losses=losses,
|
||||
)
|
||||
batchnr += 1
|
||||
except Exception as e:
|
||||
logger.error("Error updating batch:" + str(e))
|
||||
if not train_articles or articles_processed < train_articles:
|
||||
with nlp.disable_pipes("entity_linker"):
|
||||
train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||
dev=False, line_ids=batch,
|
||||
kb=kb, labels_discard=labels_discard)
|
||||
docs, golds = zip(*train_batch)
|
||||
try:
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
nlp.update(
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
sgd=optimizer,
|
||||
drop=dropout,
|
||||
losses=losses,
|
||||
)
|
||||
batchnr += 1
|
||||
articles_processed += len(docs)
|
||||
pbar.update(len(docs))
|
||||
except Exception as e:
|
||||
logger.error("Error updating batch:" + str(e))
|
||||
if batchnr > 0:
|
||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
|
||||
|
||||
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
|
||||
# STEP 5: apply the EL pipe on a toy example
|
||||
logger.info("STEP 5: Applying Entity Linking to toy example")
|
||||
run_el_toy_example(nlp=nlp)
|
||||
logging.info("Epoch {} trained on {} articles, train loss {}"
|
||||
.format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2)))
|
||||
# re-read the dev_data (data is returned as a generator)
|
||||
dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path,
|
||||
dev=True, line_ids=dev_indices,
|
||||
kb=kb, labels_discard=labels_discard)
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices))
|
||||
|
||||
if output_dir:
|
||||
# STEP 6: write the NLP pipeline (now including an EL model) to file
|
||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
# STEP 4: write the NLP pipeline (now including an EL model) to file
|
||||
logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names))
|
||||
logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
nlp.to_disk(nlp_output_dir)
|
||||
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
def check_kb(kb):
|
||||
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||
candidates = kb.get_candidates(mention)
|
||||
|
||||
logger.info("generating candidates for " + mention + " :")
|
||||
for c in candidates:
|
||||
logger.info(" ".join[
|
||||
str(c.prior_prob),
|
||||
c.alias_,
|
||||
"-->",
|
||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
|
||||
])
|
||||
|
||||
|
||||
def run_el_toy_example(nlp):
|
||||
text = (
|
||||
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
|
||||
"Douglas reminds us to always bring our towel, even in China or Brazil. "
|
||||
"The main character in Doug's novel is the man Arthur Dent, "
|
||||
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
||||
)
|
||||
doc = nlp(text)
|
||||
logger.info(text)
|
||||
for ent in doc.ents:
|
||||
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||
plac.call(main)
|
||||
|
|
|
@ -6,9 +6,6 @@ import bz2
|
|||
import logging
|
||||
import random
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
from functools import partial
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
|
@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities):
|
|||
outputfile.write(line)
|
||||
|
||||
|
||||
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
def read_training_indices(entity_file_path):
|
||||
""" This method creates two lists of indices into the training file: one with indices for the
|
||||
training examples, and one for the dev examples."""
|
||||
train_indices = []
|
||||
dev_indices = []
|
||||
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
for i, line in enumerate(file):
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
|
||||
if is_valid_article(clean_text):
|
||||
if is_dev(article_id):
|
||||
dev_indices.append(i)
|
||||
else:
|
||||
train_indices.append(i)
|
||||
|
||||
return train_indices, dev_indices
|
||||
|
||||
|
||||
def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None):
|
||||
""" This method provides training/dev examples that correspond to the entity annotations found by the nlp object.
|
||||
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
||||
For testing (kb=None), it will include all positive examples only."""
|
||||
if not labels_discard:
|
||||
labels_discard = []
|
||||
|
||||
data = []
|
||||
num_entities = 0
|
||||
get_gold_parse = partial(
|
||||
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
|
||||
)
|
||||
texts = []
|
||||
entities_list = []
|
||||
|
||||
logger.info(
|
||||
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
|
||||
)
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
with tqdm(total=limit, leave=False) as pbar:
|
||||
for i, line in enumerate(file):
|
||||
for i, line in enumerate(file):
|
||||
if i in line_ids:
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
|
@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
|||
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
||||
continue
|
||||
|
||||
doc = nlp(clean_text)
|
||||
gold = get_gold_parse(doc, entities)
|
||||
if gold and len(gold.links) > 0:
|
||||
data.append((doc, gold))
|
||||
num_entities += len(gold.links)
|
||||
pbar.update(len(gold.links))
|
||||
if limit and num_entities >= limit:
|
||||
break
|
||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||
return data
|
||||
texts.append(clean_text)
|
||||
entities_list.append(entities)
|
||||
|
||||
docs = nlp.pipe(texts, batch_size=50)
|
||||
|
||||
for doc, entities in zip(docs, entities_list):
|
||||
gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard)
|
||||
if gold and len(gold.links) > 0:
|
||||
yield doc, gold
|
||||
|
||||
|
||||
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
||||
|
|
|
@ -1308,7 +1308,7 @@ class EntityLinker(Pipe):
|
|||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
# Looping through each sentence and each entity
|
||||
# This may go wrong if there are entities across sentences - because they might not get a KB ID
|
||||
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
||||
for sent in doc.sents:
|
||||
sent_doc = sent.as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
|
|
Loading…
Reference in New Issue
Block a user